DLLM-Demo / app.py
Sergidev's picture
v1.3
aef59cd
"""
Steerling-8B Demo β€” Hugging Face Spaces (ZeroGPU)
An interactive demo for Steerling, an interpretable causal diffusion language model
with concept steering. Uses confidence-based unmasking to generate text.
Tokens are streamed live so you can watch the diffusion process fill in words
out of order β€” the signature behavior of this model.
https://huggingface.co/guidelabs/steerling-8b
"""
from __future__ import annotations
# ---------------------------------------------------------------------------
# Install steerling at startup — its package metadata requires Python >=3.13,
# but the code works fine on 3.12 (every module uses `from __future__ import annotations`).
# ---------------------------------------------------------------------------
import subprocess
import sys
# Command line for a quiet pip install of steerling into the *running*
# interpreter (sys.executable), skipping dependency resolution and the
# package's Python-version check (see the header comment above).
_PIP_INSTALL_STEERLING = (
    sys.executable,
    "-m",
    "pip",
    "install",
    "--quiet",
    "--no-deps",  # NOTE(review): presumably deps come from the Space's requirements — confirm
    "--ignore-requires-python",
    "steerling>=0.1.2",
)
# check_call raises CalledProcessError on a non-zero pip exit, aborting startup.
subprocess.check_call(list(_PIP_INSTALL_STEERLING))
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
import html as html_lib
import logging
import math
import time
from textwrap import dedent
import gradio as gr
import spaces
import torch
from steerling import SteerlingGenerator
# Root logging at INFO so the load-progress messages below are visible in the
# Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
MODEL_ID = "guidelabs/steerling-8b"
logger.info("Loading Steerling-8B model …")
# Loaded once at import time; generate_streaming reads this module-level
# `generator` (its tokenizer, model, device and special-token ids) on every request.
# NOTE(review): device="cuda" outside a @spaces.GPU function presumably relies on
# ZeroGPU's deferred-CUDA handling in the `spaces` package — confirm.
generator = SteerlingGenerator.from_pretrained(MODEL_ID, device="cuda")
logger.info("Model loaded successfully.")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Visual placeholder for masked (not yet generated) positions: U+2592 MEDIUM
# SHADE ("▒"). Written as an escape so the single-character glyph cannot be
# mangled into several characters if this file is ever re-encoded.
MASK_CHAR = "\u2592"

# Steering presets offered in the UI dropdown. Each entry maps to optional
# {concept_id: weight} dicts for known and unknown concepts; per the UI text,
# positive weights amplify a concept and negative weights suppress it.
# NOTE(review): nothing in this file documents what these concept IDs mean —
# confirm they really correspond to the labelled styles.
STEERING_PRESETS: dict[str, dict[str, dict[int, float] | None]] = {
    "None": {"steer_known": None, "steer_unknown": None},
    "More formal / academic": {
        "steer_known": {102: 2.5, 3421: 2.0, 8910: -1.5},
        "steer_unknown": None,
    },
    "More creative / poetic": {
        "steer_known": {541: 2.0, 7723: 2.5, 102: -2.0},
        "steer_unknown": None,
    },
    "More concise / factual": {
        "steer_known": {102: 1.5, 3421: 1.5, 541: -2.0, 7723: -1.5},
        "steer_unknown": None,
    },
}

# Prompts offered through gr.Examples; the first one also pre-fills the prompt box.
EXAMPLE_PROMPTS = [
    "The key to understanding neural networks is",
    "In the year 2050, renewable energy",
    "The theory of general relativity explains",
    "Once upon a time, in a land where machines could dream,",
    "The most important breakthrough in modern medicine was",
    "Artificial intelligence will",
    "The future of space exploration depends on",
    "A poem about the ocean:\n",
]
# ---------------------------------------------------------------------------
# HTML rendering helpers
# ---------------------------------------------------------------------------
def _render_html(
    prompt_text: str,
    gen_tokens: list[int],
    is_finalized: list[bool],
    just_unmasked: set[int],
    tokenizer,
    total_gen_slots: int,
    step: int,
    elapsed: float,
) -> str:
    """Build an HTML snippet showing prompt + generation with color-coded tokens.

    Args:
        prompt_text: Raw prompt; HTML-escaped and shown in blue before the generation.
        gen_tokens: Token id per generation slot.
        is_finalized: Per-slot flag; slots that are not finalized render as MASK_CHAR.
        just_unmasked: Slot indices revealed in the current step (highlighted).
        tokenizer: Object with a ``decode(list[int]) -> str`` method.
        total_gen_slots: Number of generation slots to render.
        step: Diffusion step number shown in the progress line.
        elapsed: Seconds since generation started, shown in the progress line.

    Returns:
        A self-contained ``<div>`` with the colored text and a progress bar.
    """
    escaped_prompt = html_lib.escape(prompt_text)
    parts: list[str] = []
    for i in range(total_gen_slots):
        filled = i < len(gen_tokens) and i < len(is_finalized) and is_finalized[i]
        if not filled:
            # Still masked
            parts.append(
                f'<span style="color:#555;font-family:monospace;">{MASK_CHAR}</span>'
            )
            continue
        # Each token is decoded on its own. NOTE(review): with a byte-level
        # tokenizer a character split across tokens may show as a replacement
        # glyph here — the plain-text output decodes the whole prefix instead.
        tok_text = tokenizer.decode([gen_tokens[i]])
        escaped = html_lib.escape(tok_text).replace("\n", "<br>")
        # Keep runs of spaces visible without doubling single spaces: only
        # every *pair* of spaces becomes "space + &nbsp;". (Replacing each
        # single space would render two spaces per space inside the
        # white-space:pre-wrap container below.)
        escaped = escaped.replace("  ", " &nbsp;")
        if i in just_unmasked:
            # Newly revealed this step — bright highlight
            parts.append(
                f'<span style="background:#ffe066;color:#1a1a2e;'
                f"border-radius:3px;padding:0 1px;font-weight:600;"
                f'transition:background 0.6s ease;">'
                f"{escaped}</span>"
            )
        else:
            # Previously revealed
            parts.append(f'<span style="color:#e0e0e0;">{escaped}</span>')
    gen_html = "".join(parts)

    n_filled = sum(1 for f in is_finalized if f)
    n_total = total_gen_slots
    pct = int(100 * n_filled / n_total) if n_total > 0 else 0
    # Progress bar
    bar_html = (
        f'<div style="margin:12px 0 8px 0;background:#23233a;border-radius:6px;'
        f'height:10px;width:100%;overflow:hidden;">'
        f'<div style="width:{pct}%;height:100%;background:linear-gradient(90deg,#ffe066,#ff6f61);'
        f'border-radius:6px;transition:width 0.3s ease;"></div></div>'
        f'<div style="font-size:0.82em;color:#888;margin-bottom:6px;">'
        f"Step {step} &mdash; {n_filled}/{n_total} tokens unmasked ({pct}%)"
        f" &mdash; {elapsed:.1f}s elapsed</div>"
    )
    return (
        f"<div style=\"font-family:'Inter',system-ui,sans-serif;font-size:1.05em;"
        f"line-height:1.7;padding:16px 20px;background:#1a1a2e;color:#e0e0e0;"
        f'border-radius:10px;white-space:pre-wrap;word-wrap:break-word;">'
        f'<span style="color:#82aaff;font-weight:600;">{escaped_prompt}</span>'
        f"{gen_html}"
        f"{bar_html}"
        f"</div>"
    )
def _render_final_info(
    prompt_tokens: int,
    generated_tokens: int,
    total_steps: int,
    elapsed: float,
    steering_preset: str,
    steer_known: dict[int, float] | None,
) -> str:
    """Build the Markdown run summary shown after generation finishes.

    Args:
        prompt_tokens: Number of prompt tokens.
        generated_tokens: Number of generated tokens reported to the user.
        total_steps: Number of diffusion steps taken.
        elapsed: Wall-clock seconds; throughput is 0.0 when this is not positive.
        steering_preset: Preset name, echoed verbatim.
        steer_known: Known-concept steering dict; listed only when non-empty.

    Returns:
        One Markdown line per field, separated by hard line breaks.
    """
    tok_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0
    lines = [
        f"**Prompt tokens:** {prompt_tokens}",
        f"**Generated tokens:** {generated_tokens}",
        f"**Diffusion steps:** {total_steps}",
        f"**Time:** {elapsed:.2f}s ({tok_per_sec:.1f} tok/s)",
        f"**Steering:** {steering_preset}",
    ]
    if steer_known:
        lines.append(f"**Known concept IDs:** `{steer_known}`")
    # Two trailing spaces before "\n" form a Markdown hard line break; a lone
    # newline (or one space + newline) is a soft break and would merge every
    # field into a single rendered line.
    return "  \n".join(lines)
# ---------------------------------------------------------------------------
# Streaming generation — reimplements the core loop from
# SteerlingGenerator.generate_full so that we can yield after each step.
# ---------------------------------------------------------------------------
def _resolve_steering(
    steering_preset: str,
    custom_steer_known: str,
) -> tuple[dict[int, float] | None, dict[int, float] | None]:
    """Resolve the preset name and optional custom text into steering dicts.

    Returns ``(steer_known, steer_unknown)``. A non-empty
    ``custom_steer_known`` made of comma-separated ``id:weight`` pairs
    replaces the preset's known-concept dict; the preset's unknown-concept
    dict is kept.

    Raises:
        ValueError: if a custom pair is not of the form ``int:float``.
    """
    steer_known: dict[int, float] | None = None
    steer_unknown: dict[int, float] | None = None
    if steering_preset != "None" and steering_preset in STEERING_PRESETS:
        preset = STEERING_PRESETS[steering_preset]
        steer_known = preset.get("steer_known")
        steer_unknown = preset.get("steer_unknown")
    if custom_steer_known and custom_steer_known.strip():
        parsed: dict[int, float] = {}
        for pair in custom_steer_known.split(","):
            pair = pair.strip()
            if not pair:
                continue
            # Unpacking raises ValueError unless there is exactly one ':'.
            cid, val = pair.split(":")
            parsed[int(cid.strip())] = float(val.strip())
        if parsed:
            steer_known = parsed
    return steer_known, steer_unknown


def _apply_repetition_penalty(
    logits: torch.Tensor, seen_ids: set[int], penalty: float
) -> None:
    """Penalize already-seen token ids in ``logits`` in place (CTRL-style).

    Positive logits are divided by ``penalty`` and negative ones multiplied,
    so that a penalty > 1 always makes a seen token *less* likely. Dividing a
    negative logit would raise it and boost the token instead.
    """
    if penalty == 1.0 or not seen_ids:
        return
    cols = torch.tensor(sorted(seen_ids), dtype=torch.long, device=logits.device)
    picked = logits[:, cols]
    logits[:, cols] = torch.where(picked > 0, picked / penalty, picked * penalty)


def _position_probs(
    pos_logits: torch.Tensor,
    temperature: float,
    top_p: float,
    use_entropy_sampling: bool,
) -> torch.Tensor:
    """Turn one position's logits into temperature-scaled probabilities.

    With entropy sampling, the temperature is 0.3 + 0.4 * (entropy normalized
    by log of the nucleus size), i.e. between 0.3 and 0.7. Otherwise the fixed
    ``temperature`` is used, clamped away from zero.
    """
    if use_entropy_sampling:
        raw = torch.softmax(pos_logits, dim=-1)
        sorted_probs, _ = torch.sort(raw, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        # At least 2 so the log() normalizer below is never zero.
        effective_k = max(int((cumsum <= top_p).sum().item()) + 1, 2)
        entropy = -torch.sum(raw * torch.log(raw + 1e-10))
        normalized_entropy = min(1.0, entropy.item() / math.log(effective_k))
        temp = 0.3 + 0.4 * normalized_entropy
    else:
        temp = max(temperature, 1e-8)
    return torch.softmax(pos_logits / temp, dim=-1)


def _finalized_prefix(
    token_ids: list[int], finalized: list[bool], eos_id: int | None
) -> list[int]:
    """Return the contiguous run of finalized tokens from the left, stopping
    at the first unfinalized slot or at EOS (which is excluded)."""
    prefix: list[int] = []
    for tok, done in zip(token_ids, finalized):
        if not done or (eos_id is not None and tok == eos_id):
            break
        prefix.append(tok)
    return prefix


@spaces.GPU(duration=60)
def generate_streaming(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    tokens_per_step: int,
    use_entropy_sampling: bool,
    seed: int | None,
    steering_preset: str,
    custom_steer_known: str,
):
    """Generator: yields (html_viz, generated_text, info_md) tuples.

    Runs confidence-based diffusion decoding with the module-level
    ``generator``. Each step unmasks the ``tokens_per_step`` most confident
    masked positions and yields an updated visualization.

    Yields:
        ``(html, text, markdown)``. First comes the all-masked state, then one
        tuple per diffusion step, then a final summary. Invalid input yields a
        single error tuple.
    """
    if not prompt or not prompt.strip():
        yield (
            '<div style="padding:16px;color:#ff6f61;">⚠️ Please enter a prompt.</div>',
            "",
            "",
        )
        return

    # --- resolve steering ---------------------------------------------------
    try:
        steer_known, steer_unknown = _resolve_steering(
            steering_preset, custom_steer_known
        )
    except ValueError as exc:
        # The message echoes user input, so escape it before embedding in HTML.
        yield (
            '<div style="padding:16px;color:#ff6f61;">⚠️ Could not parse custom '
            f"steering: {html_lib.escape(str(exc))}</div>",
            "",
            "",
        )
        return

    # Gradio/API callers may send floats; list repetition and topk need ints.
    # tokens_per_step < 1 would select nothing and loop forever.
    max_new_tokens = int(max_new_tokens)
    tokens_per_step = max(1, int(tokens_per_step))

    # --- setup --------------------------------------------------------------
    gen = generator
    tokenizer = gen.tokenizer
    model = gen.model
    device = gen.device
    mask_id = gen.mask_token_id
    eos_id = gen.eos_token_id
    pad_id = gen.pad_token_id

    if seed is not None and seed >= 0:
        torch.manual_seed(int(seed))

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    prompt_len = len(prompt_ids)
    total_len = prompt_len + max_new_tokens

    # Sequence = prompt tokens followed by mask tokens for every generation slot.
    x = torch.full((1, total_len), mask_id, dtype=torch.long, device=device)
    if prompt_len > 0:
        x[0, :prompt_len] = torch.tensor(prompt_ids, dtype=torch.long, device=device)

    is_prompt_mask = torch.zeros(total_len, dtype=torch.bool, device=device)
    is_prompt_mask[:prompt_len] = True
    gen_region = ~is_prompt_mask
    is_finalized = is_prompt_mask.clone()

    # Tokens that must never be sampled.
    banned_ids = {mask_id}
    if pad_id is not None:
        banned_ids.add(pad_id)
    banned_cols = sorted(banned_ids)

    # Steering intervention tensors (uses a private SteerlingGenerator helper).
    int_known_ids, int_known_vals = None, None
    int_unknown_ids, int_unknown_vals = None, None
    if gen.is_interpretable and steer_known:
        int_known_ids, int_known_vals = gen._build_intervention_tensors(
            steer_known, total_len
        )
    if gen.is_interpretable and steer_unknown:
        int_unknown_ids, int_unknown_vals = gen._build_intervention_tensors(
            steer_unknown, total_len
        )

    # Host-side mirror of the generation region, used for rendering/decoding.
    gen_token_ids: list[int] = [mask_id] * max_new_tokens
    gen_finalized: list[bool] = [False] * max_new_tokens

    def render(just_unmasked: set[int], step: int, elapsed: float) -> str:
        """Render the current generation state as HTML."""
        return _render_html(
            prompt,
            gen_token_ids,
            gen_finalized,
            just_unmasked,
            tokenizer,
            max_new_tokens,
            step,
            elapsed,
        )

    t0 = time.perf_counter()
    tokens_generated = 0
    step_count = 0

    # --- initial state (all masked) -----------------------------------------
    yield render(set(), 0, 0.0), "", "*Generating…*"

    # --- diffusion loop -----------------------------------------------------
    with torch.inference_mode():
        while tokens_generated < max_new_tokens:
            still_masked = (x[0] == mask_id) & gen_region
            masked_indices = still_masked.nonzero(as_tuple=False).squeeze(-1)
            if masked_indices.numel() == 0:
                break

            if gen.is_interpretable:
                logits, _ = model(
                    x,
                    use_teacher_forcing=False,
                    intervene_known_ids=int_known_ids,
                    intervene_known_vals=int_known_vals,
                    intervene_unknown_ids=int_unknown_ids,
                    intervene_unknown_vals=int_unknown_vals,
                    minimal_output=True,
                )
            else:
                logits = model(x)

            # Advanced indexing already copies; .float() keeps softmax/top-p
            # math in fp32 even if the model emits reduced-precision logits.
            masked_logits = logits[0, masked_indices].float()
            masked_logits[:, banned_cols] = -1e9
            if repetition_penalty != 1.0:
                seen = set(x[0, is_finalized].tolist()) - banned_ids
                _apply_repetition_penalty(masked_logits, seen, repetition_penalty)

            # Unmask the positions where the model is most confident.
            confidences = torch.softmax(masked_logits, dim=-1).max(dim=-1).values
            k = min(tokens_per_step, masked_indices.numel())
            _, selected_pos_indices = confidences.topk(k)
            step_count += 1

            just_unmasked: set[int] = set()
            hit_eos = False
            for pos_idx, seq_idx in zip(
                selected_pos_indices.tolist(),
                masked_indices[selected_pos_indices].tolist(),
            ):
                gen_slot = seq_idx - prompt_len
                pos_probs = _position_probs(
                    masked_logits[pos_idx], temperature, top_p, use_entropy_sampling
                )
                tok = _sample_top_p(pos_probs, top_p)
                x[0, seq_idx] = tok
                is_finalized[seq_idx] = True
                tokens_generated += 1
                if 0 <= gen_slot < max_new_tokens:
                    gen_token_ids[gen_slot] = tok
                    gen_finalized[gen_slot] = True
                    just_unmasked.add(gen_slot)
                if eos_id is not None and tok == eos_id:
                    hit_eos = True
                    break

            elapsed = time.perf_counter() - t0
            current_tokens = _finalized_prefix(gen_token_ids, gen_finalized, eos_id)
            current_text = tokenizer.decode(current_tokens) if current_tokens else ""
            yield (
                render(just_unmasked, step_count, elapsed),
                current_text,
                f"*Step {step_count} \u2014 {tokens_generated}/{max_new_tokens} "
                f"tokens \u2014 {elapsed:.1f}s*",
            )
            # EOS can only enter the generation region through sampling above.
            if hit_eos:
                break

    # --- final yield (no highlight) -----------------------------------------
    elapsed = time.perf_counter() - t0
    final_tokens = _finalized_prefix(gen_token_ids, gen_finalized, eos_id)
    final_text = tokenizer.decode(final_tokens) if final_tokens else ""
    final_info = _render_final_info(
        prompt_tokens=prompt_len,
        generated_tokens=len(final_tokens),
        total_steps=step_count,
        elapsed=elapsed,
        steering_preset=steering_preset,
        steer_known=steer_known,
    )
    yield render(set(), step_count, elapsed), final_text, final_info
def _sample_top_p(probs: torch.Tensor, top_p: float) -> int:
    """Sample one token id from ``probs`` with nucleus (top-p) sampling.

    Keeps the smallest probability-sorted prefix whose cumulative mass reaches
    ``top_p`` (always at least one token), renormalizes it, and draws one sample.

    Args:
        probs: 1-D probability vector over the vocabulary.
        top_p: Nucleus mass threshold; values <= 0 degrade to greedy.

    Returns:
        The sampled token id as a Python ``int``.
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # A token belongs to the nucleus while the mass *before* it is < top_p:
    # count the entries strictly below the threshold and add the one that
    # crosses it. Forcing index 0 on top of the "+1" would keep a spurious
    # second token whenever the top token alone already exceeds top_p.
    keep = int((cumulative < top_p).sum().item()) + 1
    keep = max(1, min(keep, sorted_probs.numel()))
    truncated = sorted_probs[:keep]
    truncated = truncated / truncated.sum()
    choice = torch.multinomial(truncated, 1)
    return int(sorted_indices[choice].item())
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
DESCRIPTION = dedent("""\
# 🧭 Steerling-8B Demo
**[Steerling-8B](https://huggingface.co/guidelabs/steerling-8b)** is an
8 billion parameter *causal diffusion* language model with interpretable
concept steering, built by [Guide Labs](https://www.guidelabs.ai/).
Unlike standard autoregressive LLMs, Steerling generates text by
**iteratively unmasking tokens in order of confidence** β€” the model fills
in positions where it is most certain first. Watch the diffusion process
live below!
### ✨ Key Features
| Feature | Description |
|---|---|
| 🎲 **Diffusion decoding** | Confidence-based unmasking instead of left-to-right |
| πŸ” **Interpretability** | Hidden states β†’ known + unknown concept decomposition |
| πŸŽ›οΈ **Concept steering** | Amplify or suppress concepts to guide generation |
| πŸ“ **Block-causal attention** | Bidirectional within 64-token blocks, causal across |
> ℹ️ This Space runs on **ZeroGPU** (NVIDIA H200). Generation may be
> queued briefly while a GPU is allocated.
""")
ARTICLE = dedent("""\
---
### How It Works
```
hidden β†’ known_features + unknown_features + Ξ΅ = composed β†’ logits
```
- **known_features** β€” weighted sum of top-k learned concept embeddings (interpretable)
- **unknown_features** β€” residual captured by a factorized unknown head
- **Ξ΅** β€” small correction for reconstruction fidelity
The **live visualization** above shows the diffusion process in action:
- <span style="color:#82aaff;">**Blue text**</span> = your prompt
- <span style="background:#ffe066;color:#1a1a2e;padding:0 3px;border-radius:3px;">**Highlighted**</span> = just unmasked this step
- <span style="color:#555;">β–’</span> = still masked (waiting to be filled)
Unlike autoregressive models that generate left-to-right, Steerling fills in
the **most confident positions first**, regardless of order.
### Links
- πŸ“„ [Model Card](https://huggingface.co/guidelabs/steerling-8b)
- πŸ’» [GitHub](https://github.com/guidelabs/steerling)
- 🏒 [Guide Labs](https://www.guidelabs.ai/)
- πŸ“ [Architecture Blog Post](https://www.guidelabs.ai/post/block-causal-diffusion-language-model/)
""")
CSS = """
footer { display: none !important; }
.generating { border: none !important; }
"""
# Gradio UI. The left column collects the prompt and settings. The right column
# shows the live HTML visualization, the plain-text output and the run summary.
# Layout is determined by the order in which components are created inside the
# `with` contexts.
# NOTE(review): some label strings below contain what look like mis-decoded
# emoji (e.g. "βš™οΈ", "πŸš€") — confirm the file's encoding.
with gr.Blocks(css=CSS, title="Steerling-8B Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # ── Left column: inputs ───────────────────────────────────────
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here…",
                lines=4,
                value=EXAMPLE_PROMPTS[0],
            )
            with gr.Accordion("βš™οΈ Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    16,
                    512,
                    value=128,
                    step=16,
                    label="Max new tokens",
                )
                # 0.0 is allowed; generate_streaming clamps the divisor to 1e-8.
                temperature = gr.Slider(
                    0.0,
                    2.0,
                    value=1.0,
                    step=0.05,
                    label="Temperature",
                    info="Overridden when entropy sampling is on",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-p (nucleus)",
                )
                repetition_penalty = gr.Slider(
                    1.0,
                    2.0,
                    value=1.2,
                    step=0.05,
                    label="Repetition penalty",
                )
                tokens_per_step = gr.Slider(
                    1,
                    64,
                    value=1,
                    step=1,
                    label="Tokens per step",
                    info="Unmask multiple positions per diffusion step (faster but noisier)",
                )
                use_entropy_sampling = gr.Checkbox(
                    value=True,
                    label="Entropy-adaptive sampling",
                    info="Automatically adjusts temperature (0.3–0.7) based on model uncertainty",
                )
                # Negative seeds skip torch.manual_seed in generate_streaming.
                seed = gr.Number(
                    value=42,
                    label="Seed (-1 = random)",
                    precision=0,
                )
            with gr.Accordion("πŸŽ›οΈ Concept Steering", open=False):
                gr.Markdown(
                    "Steerling decomposes hidden states into **known concepts**. "
                    "You can amplify (positive weight) or suppress (negative weight) "
                    "specific concept IDs to steer generation."
                )
                steering_preset = gr.Dropdown(
                    choices=list(STEERING_PRESETS.keys()),
                    value="None",
                    label="Steering preset",
                )
                # Parsed as comma-separated "id:weight" pairs by generate_streaming.
                custom_steer_known = gr.Textbox(
                    label="Custom known-concept overrides",
                    placeholder="e.g. 102:2.5, 541:-1.0",
                    info="Comma-separated id:weight pairs. Overrides the preset.",
                )
            generate_btn = gr.Button(
                "πŸš€ Generate",
                variant="primary",
                size="lg",
            )
            gr.Examples(
                examples=[[p] for p in EXAMPLE_PROMPTS],
                inputs=[prompt],
                label="Example prompts",
            )
        # ── Right column: outputs ─────────────────────────────────────
        with gr.Column(scale=1):
            viz_html = gr.HTML(
                label="Live diffusion",
                value=(
                    '<div style="padding:20px;text-align:center;color:#555;'
                    'font-style:italic;">Press Generate to watch the diffusion '
                    "process unfold…</div>"
                ),
            )
            generated_output = gr.Textbox(
                label="Generated text (plain)",
                lines=6,
                interactive=False,
            )
            info_md = gr.Markdown(label="Generation info")
    # Wire inputs list — order must match generate_streaming's positional
    # parameters (prompt, max_new_tokens, ..., custom_steer_known).
    inputs = [
        prompt,
        max_new_tokens,
        temperature,
        top_p,
        repetition_penalty,
        tokens_per_step,
        use_entropy_sampling,
        seed,
        steering_preset,
        custom_steer_known,
    ]
    # Matches the (html_viz, generated_text, info_md) tuples the generator yields.
    outputs = [viz_html, generated_output, info_md]
    # Both the button and pressing submit in the prompt box start a streamed run.
    generate_btn.click(
        fn=generate_streaming,
        inputs=inputs,
        outputs=outputs,
    )
    prompt.submit(
        fn=generate_streaming,
        inputs=inputs,
        outputs=outputs,
    )
    gr.Markdown(ARTICLE)

# Launch only when run as a script.
if __name__ == "__main__":
    demo.launch()