Spaces:
Running on Zero
Running on Zero
"""
Steerling-8B Demo — Hugging Face Spaces (ZeroGPU)

An interactive demo for Steerling, an interpretable causal diffusion language model
with concept steering. Uses confidence-based unmasking to generate text.
Tokens are streamed live so you can watch the diffusion process fill in words
out of order — the signature behavior of this model.

https://huggingface.co/guidelabs/steerling-8b
"""
| from __future__ import annotations | |
# ---------------------------------------------------------------------------
# Install steerling at startup — its metadata says >=3.13 but the code works
# fine on 3.12 (every module uses `from __future__ import annotations`).
# ---------------------------------------------------------------------------
| import subprocess | |
| import sys | |
# pip arguments for installing steerling into the running interpreter.
_STEERLING_PIP_ARGS = (
    "install",
    "--quiet",
    # Skip dependency resolution; presumably the Space's own requirements
    # already provide steerling's dependencies — confirm when upgrading.
    "--no-deps",
    # steerling's metadata declares Python >=3.13; this Space relies on the
    # package working on an older interpreter anyway.
    "--ignore-requires-python",
    "steerling>=0.1.2",
)
# Use the current interpreter's pip so the package lands in this environment;
# check_call raises CalledProcessError if the install fails.
subprocess.check_call([sys.executable, "-m", "pip", *_STEERLING_PIP_ARGS])
| # --------------------------------------------------------------------------- | |
| # Imports | |
| # --------------------------------------------------------------------------- | |
| import html as html_lib | |
| import logging | |
| import math | |
| import time | |
| from textwrap import dedent | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from steerling import SteerlingGenerator | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
MODEL_ID = "guidelabs/steerling-8b"
_MODEL_DEVICE = "cuda"


def _load_generator():
    """Load the Steerling generator onto ``_MODEL_DEVICE``, logging progress."""
    # NOTE(review): "β¦" looks like a mis-encoded "…"; kept verbatim here.
    logger.info("Loading Steerling-8B model β¦")
    loaded = SteerlingGenerator.from_pretrained(MODEL_ID, device=_MODEL_DEVICE)
    logger.info("Model loaded successfully.")
    return loaded


# Loaded once at import time; generate_streaming reads this module-level
# `generator` on every request.
generator = _load_generator()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Placeholder drawn for every generation slot that is still masked.
# NOTE(review): "β" looks like a mis-encoded glyph (UTF-8 read as ISO-8859-7);
# a block character such as "█" was presumably intended — confirm before
# changing it (the ARTICLE legend uses the same glyph).
MASK_CHAR = "β"

# Known-concept weights per dropdown preset: concept id -> steering weight
# (positive amplifies, negative suppresses), or None for no steering.
# NOTE(review): the concept ids are not explained in this file; confirm they
# map to the intended concepts in the model's concept vocabulary.
_PRESET_KNOWN_WEIGHTS = (
    ("None", None),
    ("More formal / academic", {102: 2.5, 3421: 2.0, 8910: -1.5}),
    ("More creative / poetic", {541: 2.0, 7723: 2.5, 102: -2.0}),
    ("More concise / factual", {102: 1.5, 3421: 1.5, 541: -2.0, 7723: -1.5}),
)

# Every preset exposes both keys read by generate_streaming; none of them
# steers unknown concepts.
STEERING_PRESETS: dict[str, dict[str, dict[int, float] | None]] = {
    name: {"steer_known": known, "steer_unknown": None}
    for name, known in _PRESET_KNOWN_WEIGHTS
}

# Prompts offered in the examples gallery; the first also seeds the prompt box.
EXAMPLE_PROMPTS = [
    "The key to understanding neural networks is",
    "In the year 2050, renewable energy",
    "The theory of general relativity explains",
    "Once upon a time, in a land where machines could dream,",
    "The most important breakthrough in modern medicine was",
    "Artificial intelligence will",
    "The future of space exploration depends on",
    "A poem about the ocean:\n",
]
| # --------------------------------------------------------------------------- | |
| # HTML rendering helpers | |
| # --------------------------------------------------------------------------- | |
| def _render_html( | |
| prompt_text: str, | |
| gen_tokens: list[int], | |
| is_finalized: list[bool], | |
| just_unmasked: set[int], | |
| tokenizer, | |
| total_gen_slots: int, | |
| step: int, | |
| elapsed: float, | |
| ) -> str: | |
| """Build an HTML snippet showing prompt + generation with color-coded tokens.""" | |
| # Escape prompt for safe HTML | |
| escaped_prompt = html_lib.escape(prompt_text) | |
| parts: list[str] = [] | |
| for i in range(total_gen_slots): | |
| if i < len(gen_tokens) and is_finalized[i]: | |
| tok_text = tokenizer.decode([gen_tokens[i]]) | |
| escaped = html_lib.escape(tok_text) | |
| # Replace newlines with <br> and preserve spaces | |
| escaped = escaped.replace("\n", "<br>") | |
| escaped = escaped.replace(" ", " ") | |
| if i in just_unmasked: | |
| # Newly revealed this step β bright highlight | |
| parts.append( | |
| f'<span style="background:#ffe066;color:#1a1a2e;' | |
| f"border-radius:3px;padding:0 1px;font-weight:600;" | |
| f'transition:background 0.6s ease;">' | |
| f"{escaped}</span>" | |
| ) | |
| else: | |
| # Previously revealed | |
| parts.append(f'<span style="color:#e0e0e0;">{escaped}</span>') | |
| else: | |
| # Still masked | |
| parts.append( | |
| f'<span style="color:#555;font-family:monospace;">{MASK_CHAR}</span>' | |
| ) | |
| gen_html = "".join(parts) | |
| n_filled = sum(1 for f in is_finalized if f) | |
| n_total = total_gen_slots | |
| pct = int(100 * n_filled / n_total) if n_total > 0 else 0 | |
| # Progress bar | |
| bar_html = ( | |
| f'<div style="margin:12px 0 8px 0;background:#23233a;border-radius:6px;' | |
| f'height:10px;width:100%;overflow:hidden;">' | |
| f'<div style="width:{pct}%;height:100%;background:linear-gradient(90deg,#ffe066,#ff6f61);' | |
| f'border-radius:6px;transition:width 0.3s ease;"></div></div>' | |
| f'<div style="font-size:0.82em;color:#888;margin-bottom:6px;">' | |
| f"Step {step} — {n_filled}/{n_total} tokens unmasked ({pct}%)" | |
| f" — {elapsed:.1f}s elapsed</div>" | |
| ) | |
| return ( | |
| f"<div style=\"font-family:'Inter',system-ui,sans-serif;font-size:1.05em;" | |
| f"line-height:1.7;padding:16px 20px;background:#1a1a2e;color:#e0e0e0;" | |
| f'border-radius:10px;white-space:pre-wrap;word-wrap:break-word;">' | |
| f'<span style="color:#82aaff;font-weight:600;">{escaped_prompt}</span>' | |
| f"{gen_html}" | |
| f"{bar_html}" | |
| f"</div>" | |
| ) | |
| def _render_final_info( | |
| prompt_tokens: int, | |
| generated_tokens: int, | |
| total_steps: int, | |
| elapsed: float, | |
| steering_preset: str, | |
| steer_known: dict[int, float] | None, | |
| ) -> str: | |
| tok_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0 | |
| info = ( | |
| f"**Prompt tokens:** {prompt_tokens} \n" | |
| f"**Generated tokens:** {generated_tokens} \n" | |
| f"**Diffusion steps:** {total_steps} \n" | |
| f"**Time:** {elapsed:.2f}s ({tok_per_sec:.1f} tok/s) \n" | |
| f"**Steering:** {steering_preset}" | |
| ) | |
| if steer_known: | |
| info += f" \n**Known concept IDs:** `{steer_known}`" | |
| return info | |
# ---------------------------------------------------------------------------
# Streaming generation — reimplements the core loop from
# SteerlingGenerator.generate_full so we can yield after each step.
# ---------------------------------------------------------------------------
def _warning_html(message: str) -> str:
    """Render ``message`` (HTML-escaped) as a red warning box for the viz pane."""
    return (
        '<div style="padding:16px;color:#ff6f61;">⚠️ '
        f"{html_lib.escape(message)}</div>"
    )


def _parse_custom_steering(text: str) -> dict[int, float]:
    """Parse ``"id:weight, id:weight, ..."`` into ``{id: weight}``.

    Empty entries (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if an entry is not exactly ``id:weight`` or a side fails
            to parse as int / float.
    """
    parsed: dict[int, float] = {}
    for pair in text.split(","):
        pair = pair.strip()
        if not pair:
            continue
        cid, val = pair.split(":")  # ValueError unless exactly one ':'
        parsed[int(cid.strip())] = float(val.strip())
    return parsed


def _resolve_steering(
    preset_name: str, custom_text: str | None
) -> tuple[dict[int, float] | None, dict[int, float] | None]:
    """Return ``(steer_known, steer_unknown)`` for a preset plus optional override.

    A non-empty custom override replaces the preset's known-concept weights;
    the preset's unknown-concept weights are kept.

    Raises:
        ValueError: if ``custom_text`` is malformed.
    """
    steer_known: dict[int, float] | None = None
    steer_unknown: dict[int, float] | None = None
    if preset_name != "None" and preset_name in STEERING_PRESETS:
        preset = STEERING_PRESETS[preset_name]
        steer_known = preset.get("steer_known")
        steer_unknown = preset.get("steer_unknown")
    if custom_text and custom_text.strip():
        parsed = _parse_custom_steering(custom_text)
        if parsed:
            steer_known = parsed
    return steer_known, steer_unknown


def _apply_repetition_penalty(
    logits: torch.Tensor, token_ids: set[int], penalty: float
) -> None:
    """Lower the scores of ``token_ids`` in every row of the 2-D ``logits``, in place.

    Positive logits are divided by ``penalty`` and negative ones multiplied by
    it (the CTRL / Hugging Face convention). Dividing unconditionally would
    pull negative logits toward zero and *raise* the probability of repeats.
    """
    if not token_ids:
        return
    idx = torch.tensor(sorted(token_ids), dtype=torch.long, device=logits.device)
    selected = logits[:, idx]
    logits[:, idx] = torch.where(selected > 0, selected / penalty, selected * penalty)


def _sample_position(
    pos_logits: torch.Tensor,
    temperature: float,
    top_p: float,
    use_entropy_sampling: bool,
) -> int:
    """Sample one token id for a single position from its 1-D logits.

    With entropy sampling, the temperature is set per position to
    ``0.3 + 0.4 * H / log(k)`` (capped at 0.7), where ``H`` is the entropy of
    the distribution and ``k >= 2`` approximates the nucleus size. Otherwise
    the fixed ``temperature`` is used; ``temperature <= 0`` means greedy
    (argmax) rather than dividing the logits by a near-zero number.
    """
    if use_entropy_sampling:
        raw = torch.softmax(pos_logits, dim=-1)
        sorted_probs, _ = torch.sort(raw, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        effective_k = max(int((cumsum <= top_p).sum().item()) + 1, 2)
        entropy = -torch.sum(raw * torch.log(raw + 1e-10))
        normalized_entropy = min(1.0, entropy.item() / math.log(effective_k))
        temp = 0.3 + 0.4 * normalized_entropy
    elif temperature <= 0:
        return int(pos_logits.argmax().item())
    else:
        temp = temperature
    return _sample_top_p(torch.softmax(pos_logits / temp, dim=-1), top_p)


def _finalized_prefix(
    token_ids: list[int], finalized: list[bool], eos_id: int | None
) -> list[int]:
    """Return the leading run of finalized tokens, stopping before the first EOS.

    The run also stops at the first still-masked slot, so the text is always a
    contiguous left-to-right prefix. The EOS token itself is not included.
    """
    prefix: list[int] = []
    for tok, done in zip(token_ids, finalized):
        if not done or (eos_id is not None and tok == eos_id):
            break
        prefix.append(tok)
    return prefix


@spaces.GPU(duration=120)
def generate_streaming(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    tokens_per_step: int,
    use_entropy_sampling: bool,
    seed: int | None,
    steering_preset: str,
    custom_steer_known: str,
):
    """Generate text for ``prompt``, yielding the diffusion state after every step.

    All generation slots start masked. Each step runs one forward pass, picks
    the ``tokens_per_step`` masked positions whose most likely token has the
    highest probability, and samples a token for each. Once an EOS is sampled,
    slots after it are no longer filled; generation ends when every slot before
    the earliest EOS (or every slot, if there is none) has a token.

    ``spaces.GPU`` requests a ZeroGPU device for each call; the 120 s budget is
    headroom for the slowest UI setting (512 tokens, one per step) — tune if needed.

    Yields:
        ``(html_viz, generated_text, info_md)`` tuples: the rendered
        visualization, the decoded contiguous prefix (up to the first masked
        slot or EOS), and a Markdown status line / final summary.
    """
    if not prompt or not prompt.strip():
        yield (_warning_html("Please enter a prompt."), "", "")
        return

    # --- resolve steering ---------------------------------------------------
    try:
        steer_known, steer_unknown = _resolve_steering(
            steering_preset, custom_steer_known
        )
    except ValueError as exc:
        yield (_warning_html(f"Could not parse custom steering: {exc}"), "", "")
        return

    # --- setup --------------------------------------------------------------
    gen = generator
    tokenizer = gen.tokenizer
    model = gen.model
    device = gen.device
    mask_id = gen.mask_token_id
    eos_id = gen.eos_token_id
    pad_id = gen.pad_token_id

    # UI sliders may deliver floats; tensor shapes, list sizes and topk need ints.
    max_new_tokens = int(max_new_tokens)
    tokens_per_step = max(1, int(tokens_per_step))

    if seed is not None and seed >= 0:
        torch.manual_seed(int(seed))

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    prompt_len = len(prompt_ids)
    total_len = prompt_len + max_new_tokens

    # Sequence = prompt tokens followed by one mask token per generation slot.
    x = torch.full((1, total_len), mask_id, dtype=torch.long, device=device)
    if prompt_len > 0:
        x[0, :prompt_len] = torch.tensor(prompt_ids, dtype=torch.long, device=device)

    is_prompt = torch.zeros(total_len, dtype=torch.bool, device=device)
    is_prompt[:prompt_len] = True
    gen_region = ~is_prompt
    is_finalized = is_prompt.clone()
    positions = torch.arange(total_len, device=device)

    # Tokens that must never be sampled.
    banned_ids = {mask_id}
    if pad_id is not None:
        banned_ids.add(pad_id)
    banned_list = list(banned_ids)

    # Steering intervention tensors (only for interpretable models).
    int_known_ids, int_known_vals = None, None
    int_unknown_ids, int_unknown_vals = None, None
    if gen.is_interpretable and steer_known:
        int_known_ids, int_known_vals = gen._build_intervention_tensors(
            steer_known, total_len
        )
    if gen.is_interpretable and steer_unknown:
        int_unknown_ids, int_unknown_vals = gen._build_intervention_tensors(
            steer_unknown, total_len
        )

    # Generation-region state mirrored into Python lists for rendering.
    gen_token_ids: list[int] = [mask_id] * max_new_tokens
    gen_finalized: list[bool] = [False] * max_new_tokens
    eos_limit = total_len  # sequence index of the earliest sampled EOS

    t0 = time.perf_counter()
    tokens_generated = 0
    step_count = 0

    # --- initial state (all masked) -----------------------------------------
    yield (
        _render_html(
            prompt, gen_token_ids, gen_finalized, set(), tokenizer,
            max_new_tokens, 0, 0.0,
        ),
        "",
        "*Generating…*",
    )

    # --- diffusion loop -----------------------------------------------------
    with torch.inference_mode():
        while True:
            # Only masked slots before the earliest EOS still need a token;
            # anything after it lies past the end of the text.
            still_masked = (x[0] == mask_id) & gen_region & (positions < eos_limit)
            masked_indices = still_masked.nonzero(as_tuple=True)[0]
            if masked_indices.numel() == 0:
                break

            if gen.is_interpretable:
                logits, _ = model(
                    x,
                    use_teacher_forcing=False,
                    intervene_known_ids=int_known_ids,
                    intervene_known_vals=int_known_vals,
                    intervene_unknown_ids=int_unknown_ids,
                    intervene_unknown_vals=int_unknown_vals,
                    minimal_output=True,
                )
            else:
                logits = model(x)

            # Advanced indexing already returns a copy; upcast so the in-place
            # edits and softmaxes below run in float32 whatever the model dtype.
            masked_logits = logits[0, masked_indices].float()
            masked_logits[:, banned_list] = -1e9

            if repetition_penalty != 1.0:
                seen = set(x[0, is_finalized].tolist()) - banned_ids
                _apply_repetition_penalty(masked_logits, seen, repetition_penalty)

            # Confidence = probability of each position's most likely token.
            confidences = torch.softmax(masked_logits, dim=-1).max(dim=-1).values
            k = min(tokens_per_step, masked_indices.numel())
            selected = confidences.topk(k).indices.tolist()

            step_count += 1
            just_unmasked: set[int] = set()
            for pos_idx in selected:
                seq_idx = int(masked_indices[pos_idx].item())
                if seq_idx >= eos_limit:
                    # An EOS sampled earlier in this step ended the text before here.
                    continue
                tok = _sample_position(
                    masked_logits[pos_idx], temperature, top_p, use_entropy_sampling
                )
                x[0, seq_idx] = tok
                is_finalized[seq_idx] = True
                tokens_generated += 1

                gen_slot = seq_idx - prompt_len
                gen_token_ids[gen_slot] = tok
                gen_finalized[gen_slot] = True
                just_unmasked.add(gen_slot)

                if eos_id is not None and tok == eos_id:
                    eos_limit = min(eos_limit, seq_idx)

            elapsed = time.perf_counter() - t0
            prefix = _finalized_prefix(gen_token_ids, gen_finalized, eos_id)
            current_text = tokenizer.decode(prefix) if prefix else ""
            yield (
                _render_html(
                    prompt, gen_token_ids, gen_finalized, just_unmasked, tokenizer,
                    max_new_tokens, step_count, elapsed,
                ),
                current_text,
                f"*Step {step_count} — {tokens_generated}/{max_new_tokens} tokens — {elapsed:.1f}s*",
            )

    # --- final yield --------------------------------------------------------
    elapsed = time.perf_counter() - t0
    final_tokens = _finalized_prefix(gen_token_ids, gen_finalized, eos_id)
    final_text = tokenizer.decode(final_tokens) if final_tokens else ""
    final_info = _render_final_info(
        prompt_tokens=prompt_len,
        generated_tokens=len(final_tokens),
        total_steps=step_count,
        elapsed=elapsed,
        steering_preset=steering_preset,
        steer_known=steer_known,
    )
    # Final view: no highlights and no placeholder slots at or past the earliest
    # EOS (those are beyond the end of the text and were never filled).
    shown = min(max_new_tokens, eos_limit - prompt_len)
    yield (
        _render_html(
            prompt, gen_token_ids[:shown], gen_finalized[:shown], set(), tokenizer,
            shown, step_count, elapsed,
        ),
        final_text,
        final_info,
    )
| def _sample_top_p(probs: torch.Tensor, top_p: float) -> int: | |
| sorted_probs, sorted_indices = torch.sort(probs, descending=True) | |
| cumulative = torch.cumsum(sorted_probs, dim=-1) | |
| cutoff_mask = cumulative <= top_p | |
| cutoff_mask[0] = True | |
| cutoff_idx = min(cutoff_mask.sum().item() + 1, len(sorted_probs)) | |
| truncated = sorted_probs[:cutoff_idx] | |
| truncated = truncated / truncated.sum() | |
| return int(sorted_indices[torch.multinomial(truncated, 1)].item()) | |
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Static page copy for the Blocks app: DESCRIPTION is rendered first (above the
# controls), ARTICLE last (below them), and CSS is passed to gr.Blocks(css=...).
# NOTE(review): many non-ASCII glyphs in these strings look like UTF-8 text
# that was decoded as ISO-8859-7 — e.g. "β" for "—", "π§ " for "🧠",
# "β¨" for "✨", "Ξ΅" for "ε". Several (a bare "π" or "β") cannot be recovered
# unambiguously, so restore them from the original source.
# NOTE(review): the strings contain no blank lines, so Markdown merges some
# paragraphs, and attaches the text after each bullet list to the last bullet.
# Confirm whether blank lines were lost.
DESCRIPTION = dedent("""\
# π§  Steerling-8B Demo
**[Steerling-8B](https://huggingface.co/guidelabs/steerling-8b)** is an
8 billion parameter *causal diffusion* language model with interpretable
concept steering, built by [Guide Labs](https://www.guidelabs.ai/).
Unlike standard autoregressive LLMs, Steerling generates text by
**iteratively unmasking tokens in order of confidence** β the model fills
in positions where it is most certain first. Watch the diffusion process
live below!
### β¨ Key Features
| Feature | Description |
|---|---|
| π² **Diffusion decoding** | Confidence-based unmasking instead of left-to-right |
| π **Interpretability** | Hidden states β known + unknown concept decomposition |
| ποΈ **Concept steering** | Amplify or suppress concepts to guide generation |
| π **Block-causal attention** | Bidirectional within 64-token blocks, causal across |
> βΉοΈ This Space runs on **ZeroGPU** (NVIDIA H200). Generation may be
> queued briefly while a GPU is allocated.
""")
# Explanatory footer: the feature decomposition, a legend for the colors used
# by the live visualization, and project links.
ARTICLE = dedent("""\
---
### How It Works
```
hidden β known_features + unknown_features + Ξ΅ = composed β logits
```
- **known_features** β weighted sum of top-k learned concept embeddings (interpretable)
- **unknown_features** β residual captured by a factorized unknown head
- **Ξ΅** β small correction for reconstruction fidelity
The **live visualization** above shows the diffusion process in action:
- <span style="color:#82aaff;">**Blue text**</span> = your prompt
- <span style="background:#ffe066;color:#1a1a2e;padding:0 3px;border-radius:3px;">**Highlighted**</span> = just unmasked this step
- <span style="color:#555;">β</span> = still masked (waiting to be filled)
Unlike autoregressive models that generate left-to-right, Steerling fills in
the **most confident positions first**, regardless of order.
### Links
- π [Model Card](https://huggingface.co/guidelabs/steerling-8b)
- π» [GitHub](https://github.com/guidelabs/steerling)
- π’ [Guide Labs](https://www.guidelabs.ai/)
- π [Architecture Blog Post](https://www.guidelabs.ai/post/block-causal-diffusion-language-model/)
""")
# Page CSS: hides the page footer and removes the border Gradio draws on
# `.generating` elements.
CSS = """
footer { display: none !important; }
.generating { border: none !important; }
"""
# Gradio layout: inputs on the left, live visualization and outputs on the
# right. Both the Generate button and submitting the prompt box stream results
# from generate_streaming into the three output components.
# NOTE(review): several UI strings below contain glyphs that look mis-encoded
# (UTF-8 read as ISO-8859-7, e.g. "β¦" for "…"); they are kept verbatim here.
with gr.Blocks(title="Steerling-8B Demo", css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # ── Left column: prompt, sampling knobs, steering, examples ──────────
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                value=EXAMPLE_PROMPTS[0],
                label="Prompt",
                placeholder="Enter your prompt hereβ¦",
                lines=4,
            )

            with gr.Accordion("βοΈ Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    minimum=16, maximum=512, value=128, step=16,
                    label="Max new tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0, maximum=2.0, value=1.0, step=0.05,
                    label="Temperature",
                    info="Overridden when entropy sampling is on",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-p (nucleus)",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.05,
                    label="Repetition penalty",
                )
                tokens_per_step = gr.Slider(
                    minimum=1, maximum=64, value=1, step=1,
                    label="Tokens per step",
                    info="Unmask multiple positions per diffusion step (faster but noisier)",
                )
                use_entropy_sampling = gr.Checkbox(
                    value=True,
                    label="Entropy-adaptive sampling",
                    info="Automatically adjusts temperature (0.3β0.7) based on model uncertainty",
                )
                seed = gr.Number(value=42, label="Seed (-1 = random)", precision=0)

            with gr.Accordion("ποΈ Concept Steering", open=False):
                gr.Markdown(
                    "Steerling decomposes hidden states into **known concepts**. "
                    "You can amplify (positive weight) or suppress (negative weight) "
                    "specific concept IDs to steer generation."
                )
                steering_preset = gr.Dropdown(
                    choices=list(STEERING_PRESETS.keys()),
                    value="None",
                    label="Steering preset",
                )
                custom_steer_known = gr.Textbox(
                    label="Custom known-concept overrides",
                    placeholder="e.g. 102:2.5, 541:-1.0",
                    info="Comma-separated id:weight pairs. Overrides the preset.",
                )

            generate_btn = gr.Button("π Generate", variant="primary", size="lg")

            gr.Examples(
                examples=[[p] for p in EXAMPLE_PROMPTS],
                inputs=[prompt],
                label="Example prompts",
            )

        # ── Right column: streamed visualization and results ─────────────────
        with gr.Column(scale=1):
            viz_html = gr.HTML(
                label="Live diffusion",
                value='<div style="padding:20px;text-align:center;color:#555;font-style:italic;">Press Generate to watch the diffusion process unfoldβ¦</div>',
            )
            generated_output = gr.Textbox(
                label="Generated text (plain)", lines=6, interactive=False
            )
            info_md = gr.Markdown(label="Generation info")

    # Order must match generate_streaming's positional parameters.
    inputs = [
        prompt,
        max_new_tokens,
        temperature,
        top_p,
        repetition_penalty,
        tokens_per_step,
        use_entropy_sampling,
        seed,
        steering_preset,
        custom_steer_known,
    ]
    # Order must match the (html_viz, generated_text, info_md) yielded tuples.
    outputs = [viz_html, generated_output, info_md]

    # Same streaming handler for the button click and for submitting the prompt.
    for _trigger in (generate_btn.click, prompt.submit):
        _trigger(fn=generate_streaming, inputs=inputs, outputs=outputs)

    gr.Markdown(ARTICLE)

if __name__ == "__main__":
    demo.launch()