# NOTE: the three lines below were Hugging Face web-page residue accidentally
# pasted into the file ("basilboy's picture / Update app.py / b862b43 verified");
# kept here as a comment so the module parses as valid Python.
# app.py
import os, re, math, random, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from safetensors.torch import load_file as load_sft
from huggingface_hub import snapshot_download
# Keep all computation in float32 (checkpoints are force-cast on load too).
torch.set_default_dtype(torch.float32)
# ===============================================
# Default config (from your training notes)
# ===============================================
# Fallback hyper-parameters; entries are overridden by values read from, or
# inferred out of, a loaded checkpoint (see _read_config_from_dict_or_infer).
DEFAULT_CONF = {
    "embed_dim": 1024,        # model width E
    "num_heads": 8,           # attention heads (E must be divisible by this)
    "expansion_factor": 4,    # MLP hidden size = E * expansion_factor
    "num_blocks": 8,          # number of AttnBlock layers
    "radius": 16,             # local-attention band half-width, in tokens
    "tokenizer_name": "gpt2",
}
# ===============================================
# Minimal CNA (inference-ready)
# ===============================================
class AttnBlock(nn.Module):
    """One pre-norm transformer block with banded (local) self-attention.

    Attention is restricted to a symmetric band (NOT causal): token i attends
    to token j iff |i - j| <= radius. Interleaved RoPE is applied to Q and K.
    Residual branches (Wo and the last MLP layer) are zero-initialised to
    match training, so a fresh block is initially the identity.
    """

    def __init__(self, embed_dim, num_heads, expansion_factor):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.norm1 = nn.LayerNorm(embed_dim)
        self.QKV = nn.Linear(embed_dim, embed_dim * 3)  # fused Q/K/V projection
        self.Wo = nn.Linear(embed_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * expansion_factor),
            nn.GELU(),
            nn.Linear(embed_dim * expansion_factor, embed_dim),
        )
        # zero-init residual branches (match training)
        nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
        nn.init.zeros_(self.mlp[-1].weight); nn.init.zeros_(self.mlp[-1].bias)

    def rope(self, Qh, Kh_seq, cos, sin):
        """Apply interleaved rotary position embeddings to Q and K.

        Even/odd feature pairs (0::2 / 1::2) are rotated by per-position angle
        tables ``cos``/``sin``; the tables arrive with angles duplicated into
        adjacent slots, so ``cos[..., 0::2]`` recovers the raw angle table
        (see CNA._rope_seq). Returns the rotated (Q, K) pair.
        """
        Qe = Qh[..., 0::2]; Qo = Qh[..., 1::2]
        ce = cos[..., 0::2]; se = sin[..., 0::2]
        Qr_e = Qe * ce - Qo * se
        Qr_o = Qe * se + Qo * ce
        Qh2 = torch.empty_like(Qh); Qh2[..., 0::2] = Qr_e; Qh2[..., 1::2] = Qr_o
        Ke = Kh_seq[..., 0::2]; Ko = Kh_seq[..., 1::2]
        Kr_e = Ke * ce - Ko * se
        Kr_o = Ke * se + Ko * ce
        Kh2 = torch.empty_like(Kh_seq); Kh2[..., 0::2] = Kr_e; Kh2[..., 1::2] = Kr_o
        return Qh2, Kh2

    def forward(self, x, rope, radius):
        """Run the block on activations x of shape [B, S, E].

        Args:
            x: input activations [B, S, E].
            rope: (cos, sin) tables from CNA._rope_seq, each [1, S, 1, head_dim].
            radius: band half-width; positions further apart are masked out.
        """
        # keep LN inputs & params same dtype
        if x.dtype != self.norm1.weight.dtype:
            x = x.to(self.norm1.weight.dtype)
        h = self.norm1(x)
        B, S, E = h.shape
        cos, sin = rope
        nh, hd = self.num_heads, self.head_dim
        cos = cos.to(h.dtype).to(h.device).permute(0,2,1,3)  # [1,1,S,hd]
        sin = sin.to(h.dtype).to(h.device).permute(0,2,1,3)
        # local band mask: 0 inside the band, dtype-min (effectively -inf) outside
        idx = torch.arange(S, device=h.device)
        idx_dist = (idx.view(1, S) - idx.view(S, 1)).abs()
        neg_inf = torch.finfo(h.dtype).min
        mask = torch.full((S, S), neg_inf, dtype=h.dtype, device=h.device)
        mask[idx_dist <= int(radius)] = 0
        mask = mask.view(1, 1, S, S)
        qkv = self.QKV(h)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape to [B, heads, S, head_dim]
        Qh = q.view(B,S,nh,hd).permute(0,2,1,3).contiguous()
        Kh_seq = k.view(B,S,nh,hd).permute(0,2,1,3).contiguous()
        Vh = v.view(B,S,nh,hd).permute(0,2,1,3).contiguous()
        assert hd % 2 == 0, "rope needs even head_dim"
        Qh, Kh_seq = self.rope(Qh, Kh_seq, cos, sin)
        Kh = Kh_seq.permute(0,1,3,2).contiguous()  # transpose K for Q @ K^T
        logits = (Qh @ Kh) * (hd ** -0.5)  # scaled dot-product scores
        attn = F.softmax(logits + mask, dim=-1) @ Vh
        attn = attn.permute(0,2,1,3).contiguous().view(B,S,E)
        x = x + self.Wo(attn)            # residual attention branch
        x = x + self.mlp(self.norm2(x))  # residual MLP branch
        return x
class CNA(nn.Module):
    """Denoising model: token embedding -> num_blocks banded-attention blocks -> vocab logits.

    There is no learned positional embedding table; position enters only
    through the RoPE tables built per forward pass in _rope_seq.
    """

    def __init__(self, embed_dim, num_heads, expansion_factor, num_blocks, radius, vocab_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.expansion_factor = expansion_factor
        self.num_blocks = num_blocks
        self.vocab_size = vocab_size
        self.radius = radius
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([AttnBlock(embed_dim, num_heads, expansion_factor) for _ in range(num_blocks)])
        self.proj = nn.Linear(embed_dim, vocab_size)

    def _rope_seq(self, S, hd, device, dtype, base=10000.0):
        """Build interleaved RoPE tables for a length-S sequence.

        Returns (cos, sin), each of shape [1, S, 1, hd], with each angle
        duplicated into adjacent even/odd feature slots so that
        ``table[..., 0::2]`` recovers the raw [S, hd//2] angle table
        (consumed that way in AttnBlock.rope).
        """
        pos = torch.arange(S, device=device, dtype=dtype)
        half = hd // 2
        idx = torch.arange(half, device=device, dtype=dtype)
        inv = base ** (-idx / half)       # inverse frequencies per feature pair
        ang = pos[:, None] * inv[None, :] # [S, hd//2] rotation angles
        cos = ang.cos().unsqueeze(0).unsqueeze(2)
        sin = ang.sin().unsqueeze(0).unsqueeze(2)
        # interleave each angle into (even, odd) slots -> [1, S, 1, hd]
        cos = torch.stack((cos, cos), dim=-1).reshape(1, S, 1, hd)
        sin = torch.stack((sin, sin), dim=-1).reshape(1, S, 1, hd)
        return cos, sin

    def forward(self, x):
        """Map int64 token ids [B, S] (or pre-embedded activations) to logits [B, S, vocab_size]."""
        if x.dtype == torch.long and x.dim() == 2:
            h = self.tok_emb(x)
        else:
            h = x  # assumes x is already [B, S, E] embeddings — TODO confirm callers
        # ensure embeddings/activations dtype follows model dtype
        target_dtype = next(self.parameters()).dtype
        if h.dtype != target_dtype:
            h = h.to(target_dtype)
        B, S, E = h.shape
        hd = self.embed_dim // self.num_heads
        cos, sin = self._rope_seq(S, hd, h.device, h.dtype)
        for blk in self.blocks:
            h = blk(h, rope=(cos, sin), radius=self.radius)
        return self.proj(h)
# ===============================================
# Helpers
# ===============================================
def to_batch2(ids_like) -> torch.Tensor:
    """Coerce token ids (flat list, nested list, or tensor) to int64 shape [1, S].

    Accepts [S], [1, S], [1, 1, S]; anything else is flattened into one row.
    """
    t = torch.tensor(ids_like, dtype=torch.long)
    if t.dim() == 2:
        return t
    if t.dim() == 1:
        return t.unsqueeze(0)       # [S] -> [1, S]
    if t.dim() == 3 and t.shape[0] == 1 and t.shape[1] == 1:
        return t.squeeze(1)         # [1, 1, S] -> [1, S]
    return t.view(1, -1)            # fallback: flatten to a single row
def infer_expansion_factor_from_state(state, embed_dim):
    """Infer the MLP expansion factor from the first block's MLP weights.

    Prefers the input projection (rows = E * factor), then the output
    projection (cols = E * factor); falls back to DEFAULT_CONF otherwise.
    """
    w_in = state.get("blocks.0.mlp.0.weight")
    if w_in is not None:
        return int(w_in.shape[0] // embed_dim)
    w_out = state.get("blocks.0.mlp.2.weight")
    if w_out is not None:
        return int(w_out.shape[1] // embed_dim)
    return DEFAULT_CONF["expansion_factor"]
@torch.no_grad()
def decode(ids, tokenizer, max_chars=1000):
    """Decode a 1-D id tensor to display text: newlines flattened to spaces,
    truncated to max_chars with a trailing ellipsis when clipped."""
    text = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
    flat = text.replace("\n", " ")
    if len(flat) > max_chars:
        return flat[:max_chars] + "…"
    return flat
@torch.no_grad()
def model_logits(model, x):
    """Run a forward pass without building an autograd graph."""
    logits = model(x)
    return logits
def to_fixed_len_ids(text, tokenizer, seqlen, pad_mode="random", rnd=None):
    """Encode `text` and force it to exactly `seqlen` tokens, shape [1, seqlen].

    Overlong encodings are truncated; short ones are padded either with the
    tokenizer's EOS token (pad_mode == "eos", when available) or with random
    vocabulary ids drawn from `rnd`.
    """
    rnd = rnd or random.Random()
    ids = tokenizer.encode(text, add_special_tokens=False)[:seqlen]
    shortfall = seqlen - len(ids)
    if shortfall > 0:
        if pad_mode == "eos" and tokenizer.eos_token_id is not None:
            ids += [tokenizer.eos_token_id] * shortfall
        else:
            vocab = tokenizer.vocab_size
            ids += [rnd.randrange(vocab) for _ in range(shortfall)]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)
def _parse_index_csv(indices_csv):
    """Parse a '0, 5, 10-20' style CSV into a set of token positions.

    Ranges may be written in either order ('20-10' works); malformed entries
    (e.g. 'abc', '3-') are skipped silently so a typo never crashes the UI.
    """
    idxs = set()
    if not indices_csv or not indices_csv.strip():
        return idxs
    for part in indices_csv.split(","):
        part = part.strip()
        if not part:
            continue
        try:
            if "-" in part:
                a, b = part.split("-", 1)
                a, b = int(a), int(b)
                idxs.update(range(min(a, b), max(a, b) + 1))
            else:
                idxs.add(int(part))
        except ValueError:
            # was a bare `except:` — narrow to the int() parse failure only
            pass
    return idxs


def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right, seqlen, seed=0):
    """Randomise tokens at the positions named in `indices_csv`, optionally
    prepend/append random tokens, then re-fit to exactly `seqlen` tokens.

    Args:
        x: int64 tensor [1, S] of token ids (not mutated; a clone is edited).
        tokenizer: provides vocab_size for random draws.
        indices_csv: positions like '0, 5, 10-20' (see _parse_index_csv).
        add_noise_left / add_noise_right: counts of random tokens to attach.
        seqlen: final length — longer results are truncated, shorter ones
            padded with random tokens.
        seed: seeds a private random.Random so results are reproducible.

    Returns:
        int64 tensor [1, seqlen].
    """
    rnd = random.Random(seed)
    V = tokenizer.vocab_size
    x = x.clone()
    # Replace each requested, in-range position with a random token.
    # Sorted iteration keeps RNG consumption order deterministic.
    for j in sorted(_parse_index_csv(indices_csv)):
        if 0 <= j < x.shape[1]:
            x[0, j] = rnd.randrange(V)
    if add_noise_left > 0:
        prefix = torch.tensor([rnd.randrange(V) for _ in range(int(add_noise_left))], dtype=torch.long).unsqueeze(0)
        x = torch.cat([prefix, x], dim=1)
    if add_noise_right > 0:
        suffix = torch.tensor([rnd.randrange(V) for _ in range(int(add_noise_right))], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, suffix], dim=1)
    # Re-fit to exactly seqlen: truncate, or pad the tail with random tokens.
    if x.shape[1] > seqlen:
        x = x[:, :seqlen]
    elif x.shape[1] < seqlen:
        need = seqlen - x.shape[1]
        pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, pad], dim=1)
    return x
@torch.no_grad()
def sample_from_logits(logits_row, temperature=1.0, current_token=None, exclude_current=True):
    """Sample a token id from a single row of logits.

    temperature <= 0 degenerates to argmax. When excluding the current token,
    its probability is zeroed and the rest renormalised; if nothing remains,
    fall back to argmax.
    """
    greedy = int(torch.argmax(logits_row).item())
    if temperature <= 0:
        return greedy
    probs = torch.softmax(logits_row / float(temperature), dim=-1)
    if exclude_current and current_token is not None:
        probs = probs.clone()
        probs[current_token] = 0.0
        total = probs.sum()
        if total.item() <= 0:
            return greedy
        probs = probs / total
    return int(torch.multinomial(probs, 1).item())
# ===============================================
# Weight loading (file / folder / HF Hub)
# ===============================================
# Fallback checkpoint locations used by load_model when the UI's source box
# does not resolve; overridable via environment variables.
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
DEFAULT_WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "weights_latest")
def _read_config_from_dict_or_infer(state, cfg):
    """Merge DEFAULT_CONF with checkpoint cfg, then override with values
    inferred from the state dict itself (embed dim, block count, expansion)."""
    merged = dict(DEFAULT_CONF)
    merged.update(cfg or {})
    emb = state.get("tok_emb.weight")
    if emb is not None:
        merged["embed_dim"] = emb.shape[1]
    block_ids = []
    for key in state:
        m = re.match(r"blocks\.(\d+)\.", key)
        if m:
            block_ids.append(int(m.group(1)))
    if block_ids:
        merged["num_blocks"] = max(block_ids) + 1
    if "blocks.0.mlp.0.weight" in state or "blocks.0.mlp.2.weight" in state:
        merged["expansion_factor"] = infer_expansion_factor_from_state(state, merged["embed_dim"])
    # guarantee a usable tokenizer name even if cfg carried None/empty
    merged["tokenizer_name"] = merged.get("tokenizer_name") or "gpt2"
    return merged
def _is_state_dict(obj):
if isinstance(obj, dict) and obj:
sample_val = next(iter(obj.values()))
return isinstance(sample_val, torch.Tensor)
return False
def _load_state_from_pt(path: str):
    """Load weights from a .pt/.bin file.

    Accepts either a raw state dict, or a payload dict carrying a 'model'
    state dict (plus optional 'config' / 'tokenizer_name' entries).
    Returns (state, cfg); raises ValueError on anything else.
    """
    payload = torch.load(path, map_location="cpu")
    if isinstance(payload, dict):
        model_sd = payload.get("model")
        if isinstance(model_sd, dict):
            cfg = dict(payload.get("config") or {})
            if "tokenizer_name" in payload:
                cfg["tokenizer_name"] = payload["tokenizer_name"]
            return model_sd, cfg
    if _is_state_dict(payload):
        return payload, {}
    raise ValueError(f"Unsupported .pt format at {path}: expected a state_dict or a payload with 'model'.")
def _merge_state_dicts(dicts):
merged = {}
for d in dicts:
for k, v in d.items():
merged[k] = v
return merged
def _load_state_from_folder(weights_dir: str):
    """Load weights (and optional config.json) from a local folder.

    Priority: a single 'model.safetensors', then any set of .safetensors
    shards (merged), then .pt/.bin files (raw state dicts or
    {'model': ..., 'config': ...} payloads). Returns (state, cfg).

    Raises:
        FileNotFoundError: folder missing, or no weight files inside it.
        ValueError: a .pt/.bin shard has an unrecognised format.
    """
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"Folder not found: {weights_dir}")
    cfg_path = os.path.join(weights_dir, "config.json")
    cfg = {}
    if os.path.exists(cfg_path):
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
    files = sorted(os.listdir(weights_dir))
    sft_files = [f for f in files if f.endswith(".safetensors")]
    pt_files = [f for f in files if f.endswith(".pt") or f.endswith(".bin")]
    state = None
    if "model.safetensors" in sft_files:
        state = load_sft(os.path.join(weights_dir, "model.safetensors"))
    elif sft_files:
        # multiple shards: merge them (later files win on duplicate keys)
        parts = [load_sft(os.path.join(weights_dir, f)) for f in sft_files]
        state = _merge_state_dicts(parts)
    elif pt_files:
        parts = []
        for f in pt_files:
            part = torch.load(os.path.join(weights_dir, f), map_location="cpu")
            if isinstance(part, dict) and "model" in part and isinstance(part["model"], dict):
                parts.append(part["model"])
                # fold shard-level config / tokenizer hints into cfg
                if "config" in part and isinstance(part["config"], dict):
                    cfg = {**cfg, **part["config"]}
                if "tokenizer_name" in part:
                    cfg.setdefault("tokenizer_name", part["tokenizer_name"])
            elif _is_state_dict(part):
                parts.append(part)
            else:
                raise ValueError(f"Unsupported shard format: {f}")
        state = _merge_state_dicts(parts)
    else:
        raise FileNotFoundError(
            f"No weights found in {weights_dir}. Expected .safetensors or .pt files."
        )
    return state, cfg
def _load_state_from_hub(repo_id: str, subfolder: str | None = None, revision: str | None = None):
    """Snapshot a Hugging Face Hub repo locally, then load weights from it
    (or from `subfolder` within it) via _load_state_from_folder."""
    cache_dir = snapshot_download(repo_id=repo_id, revision=revision, allow_patterns=None)
    path = os.path.join(cache_dir, subfolder) if subfolder else cache_dir
    return _load_state_from_folder(path)
def load_model(source: str):
    """Resolve `source` into a ready CNA model.

    `source` may be a .pt/.bin file path, a local folder, a Hub repo id
    (anything containing '/'), or empty — in which case well-known local
    names and the env-configured defaults are tried in order.

    Returns:
        (model, tokenizer, radius) — model forced to float32 and eval mode.

    Raises:
        FileNotFoundError: no weights could be resolved from any source.
    """
    src = source or ""
    state, cfg = None, {}
    if os.path.isfile(src) and (src.endswith(".pt") or src.endswith(".bin")):
        state, cfg = _load_state_from_pt(src)
    elif os.path.isdir(src):
        state, cfg = _load_state_from_folder(src)
    elif "/" in src:  # Hub repo id
        subfolder = os.environ.get("WEIGHTS_SUBFOLDER") or None
        revision = os.environ.get("WEIGHTS_REVISION") or None
        state, cfg = _load_state_from_hub(src, subfolder=subfolder, revision=revision)
    else:
        # fallbacks: well-known local names, then env-configured defaults
        if os.path.isfile("weights_latest.pt"):
            state, cfg = _load_state_from_pt("weights_latest.pt")
        elif os.path.isfile(DEFAULT_CKPT):
            state, cfg = _load_state_from_pt(DEFAULT_CKPT)
        elif os.path.isdir(DEFAULT_WEIGHTS_DIR):
            state, cfg = _load_state_from_folder(DEFAULT_WEIGHTS_DIR)
        else:
            raise FileNotFoundError(
                f"Could not resolve weights from '{src}'. Tried file (.pt), folder, hub repo id, "
                f"then defaults ('{DEFAULT_CKPT}', '{DEFAULT_WEIGHTS_DIR}')."
            )
    conf = _read_config_from_dict_or_infer(state, cfg)
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(conf["tokenizer_name"], use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 1_000_000_000  # disable length truncation warnings
    vocab_size = tokenizer.vocab_size
    # Build model
    model = CNA(
        conf["embed_dim"], conf["num_heads"], conf["expansion_factor"],
        conf["num_blocks"], conf["radius"], vocab_size
    )
    # Load state, tolerating a projection (vocab) size mismatch.
    # NOTE: load_state_dict(strict=False) only tolerates missing/unexpected
    # KEYS — it still raises on SHAPE mismatches. So drop shape-mismatched
    # tensors first, then let strict=False report them as missing keys.
    model_sd = model.state_dict()
    filtered = {k: v for k, v in state.items()
                if k not in model_sd or model_sd[k].shape == v.shape}
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    if any(k.startswith("proj.") for k in missing):
        # checkpoint's output projection absent or differently sized: re-init
        with torch.no_grad():
            nn.init.normal_(model.proj.weight, std=0.02)
            nn.init.zeros_(model.proj.bias)
    # enforce float32 across params & buffers
    model = model.to(torch.float32)
    with torch.no_grad():
        for p in model.parameters():
            if p.dtype.is_floating_point:
                p.data = p.data.float()
        for _, buf in model.named_buffers():
            if buf.dtype.is_floating_point:
                buf.data = buf.data.float()
    model.eval()
    return model, tokenizer, conf["radius"]
# Process-wide cache of the most recently loaded model/tokenizer, keyed by
# the weights source string stored in "ckpt" (see ensure_model).
model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
def _auto_default_source():
    """Pick a default weights source: WEIGHTS_SOURCE env var, then well-known
    local paths, then any checkpoint-looking file in the CWD, then a
    last-resort name (which load_model may still fail on)."""
    env = os.environ.get("WEIGHTS_SOURCE")
    if env:
        return env
    if os.path.isdir("weights_latest"):
        return "weights_latest"
    for candidate in ("weights_latest.pt", "ckpt_latest.pt"):
        if os.path.isfile(candidate):
            return candidate
    for entry in sorted(os.listdir(".")):
        if entry.endswith((".pt", ".safetensors")):
            return entry
    return "weights_latest.pt"
def ensure_model(source_path_or_repo):
    """(Re)load the model into `model_cache` iff the requested source changed
    (or nothing is loaded yet)."""
    src = source_path_or_repo or _auto_default_source()
    if model_cache["model"] is not None and model_cache["ckpt"] == src:
        return
    m, tok, rad = load_model(src)
    model_cache.update({"model": m, "tokenizer": tok, "radius": rad, "ckpt": src})
# ===============================================
# Strategy 1 (random position) with argmax / sample
# ===============================================
@torch.no_grad()
def step_strategy1(model, x, mode="argmax", temperature=1.0, exclude_current=True):
    """One denoising step: pick a uniformly random position and rewrite its
    token from the model's logits there (argmax or sampled).

    Mutates x in place and also returns it. Uses torch's RNG for the
    position so torch.manual_seed makes runs reproducible.
    """
    S = x.shape[1]
    pos = int(torch.randint(0, S, (1,)).item())
    logits_pos = model_logits(model, x)[0, pos]
    if mode != "sample":
        x[0, pos] = int(torch.argmax(logits_pos).item())
        return x
    cur_tok = int(x[0, pos].item())
    x[0, pos] = sample_from_logits(
        logits_pos,
        temperature=float(temperature),
        current_token=cur_tok,
        exclude_current=bool(exclude_current),
    )
    return x
# ===============================================
# Gradio callbacks
# ===============================================
def init_random(src, seqlen, seed):
    """Gradio callback: seed both RNGs and create a fresh random sequence.

    Returns (ids_as_nested_list, decoded_text, status_message).
    """
    ensure_model(src)
    random.seed(seed); torch.manual_seed(seed)
    tok = model_cache["tokenizer"]
    n = int(seqlen)
    x = torch.randint(0, tok.vocab_size, (1, n))
    return x.tolist(), decode(x[0], tok), f"Initialized random sequence (len={n})"
def to_ranges(indices):
    """Compress token indices (any order, duplicates OK) into 'a-b, c' CSV runs."""
    if not indices:
        return ""
    ordered = sorted(set(indices))
    runs = []
    run_start = run_end = ordered[0]
    for i in ordered[1:]:
        if i == run_end + 1:
            run_end = i
            continue
        runs.append((run_start, run_end))
        run_start = run_end = i
    runs.append((run_start, run_end))
    return ", ".join(f"{a}-{b}" if a != b else f"{a}" for a, b in runs)
def capture_selection(text, seqlen, current_ids, evt: gr.SelectData | None = None):
    """
    Map highlighted character span in `text` to token index ranges using tokenizer offsets.
    Auto-fills the indices box so you can 'Noise Selection'.

    Returns (indices_csv_or_noop_update, status_message). `current_ids` is
    accepted for wiring symmetry but not read here.
    """
    ensure_model(None)
    tok = model_cache["tokenizer"]
    if not text:
        return gr.update(), "No text to select from."
    # Try to read (start, end) from the event payload
    start, end = None, None
    if evt is not None:
        try:
            # gradio SelectData for Textbox exposes .index = (start_char, end_char)
            start, end = evt.index
        except Exception:
            # payload shape varies across gradio versions; treat as no selection
            pass
    # Fallback: nothing selected
    if start is None or end is None or start == end:
        return gr.update(), "No selection detected (drag to highlight)."
    # Bound the indices defensively
    start = max(0, min(len(text), int(start)))
    end = max(0, min(len(text), int(end)))
    # Get per-token char offsets from the fast tokenizer
    enc = tok(text, add_special_tokens=False, return_offsets_mapping=True)
    offsets = enc["offset_mapping"]  # list of (s,e) per token
    token_idxs = []
    for i, (s, e) in enumerate(offsets):
        if s is None or e is None:
            continue
        # overlap if token span intersects [start, end)
        if max(s, start) < min(e, end):
            token_idxs.append(i)
    if not token_idxs:
        return gr.update(), "Selection didn't hit any tokens (maybe whitespace)."
    # Clip to current sequence length (so we don't index beyond S)
    S = int(seqlen)
    token_idxs = [i for i in token_idxs if i < S]
    if not token_idxs:
        return gr.update(), "Selected span maps beyond current sequence length."
    indices_csv = to_ranges(token_idxs)
    return indices_csv, f"Selected chars [{start}:{end}) → tokens {indices_csv}"
def noise_selection(src, state_ids, seqlen, indices_csv, seed):
    """Gradio callback: noise only the selected token indices."""
    # Reuse apply_noise but force prepend/append noise to zero
    return apply_noise(src, state_ids, seqlen, indices_csv, 0, 0, seed)
def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
    """Gradio callback: apply index/edge noise to the current sequence, or to
    a fresh random one when no sequence exists yet.

    Returns (ids_as_nested_list, decoded_text, status_message).
    """
    ensure_model(src)
    tok = model_cache["tokenizer"]
    S = int(seqlen)
    if state_ids:
        base = to_batch2(state_ids)
    else:
        base = torch.randint(0, tok.vocab_size, (1, S))
    x = apply_noise_ops(base, tok, indices_csv, int(add_left or 0), int(add_right or 0), S, seed=seed)
    return x.tolist(), decode(x[0], tok), "Applied noise"
def step_once(src, state_ids, mode, temperature, exclude_current):
    """Gradio callback: run a single denoising step on the stored sequence."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if not state_ids:
        return None, "", "No sequence to step — initialize first."
    x = step_strategy1(
        model_cache["model"], to_batch2(state_ids),
        mode=mode, temperature=temperature, exclude_current=exclude_current,
    )
    return x.tolist(), decode(x[0], tok), f"Stepped 1 iteration ({mode})"
def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exclude_current):
    """Gradio generator callback: run `steps` denoise iterations, yielding an
    (ids, text, status) snapshot every `snap_every` steps and at the end.

    Yields nothing (returns immediately) when there is no current sequence.
    """
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        return
    random.seed(seed); torch.manual_seed(seed)
    x = to_batch2(state_ids)
    total = int(steps); snap = max(1, int(snap_every))
    for t in range(1, total + 1):
        x = step_strategy1(model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current)
        # snapshot periodically, and always on the final step
        if (t % snap == 0) or (t == total):
            txt = decode(x[0], tok)
            yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
# ===============================================
# UI (single mode)
# ===============================================
# Flat script section: build the single-page Gradio UI and launch it.
with gr.Blocks(title="Self Organising Text Demo") as demo:
    gr.Markdown(
        """
# Self Organising Text Demo
Watch text self organise using only local attention.
"""
    )
    # Resolve the default weights source once, at UI-build time.
    default_source = os.environ.get("WEIGHTS_SOURCE", None)
    if default_source is None:
        default_source = _auto_default_source()
    with gr.Row():
        src = gr.Textbox(value=default_source, label="Weights (file / folder / HF repo id)")
        seqlen = gr.Slider(10, 512, value=50, step=1, label="Sequence length (S)")
        seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
    # Hidden state: current token ids as a nested list (shape [1, S]).
    ids_state = gr.State(value=None)
    with gr.Row():
        current_text = gr.Textbox(lines=8, label="Current text", interactive=True)
        status = gr.Markdown("Ready.")
    gr.Markdown("### Initialize & Denoise")
    with gr.Row():
        btn_random = gr.Button("Initialize Random")
        steps = gr.Slider(1, 2000, value=100, step=1, label="Denoise steps (N)")  # default 100
        snap_every = gr.Slider(1, 100, value=1, step=1, label="Update every K steps")  # default 1
    with gr.Row():
        update_mode = gr.Radio(
            choices=["argmax", "sample"],
            value="sample",  # default to sampling
            label="Update rule"
        )
        temperature = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.05, label="Temperature (sampling)")
        exclude_current = gr.Checkbox(value=True, label="Exclude current token when sampling")
    with gr.Row():
        btn_step_once = gr.Button("Step Once")
        btn_live = gr.Button("Denoise Live (streaming)")
    gr.Markdown("### Noise by Indices")
    with gr.Row():
        indices_csv = gr.Textbox(
            label="Positions to noise (enter like: 0, 5, 10-20)",
            placeholder="e.g., 0, 5, 10-20"
        )
    with gr.Row():
        add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
        add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
        btn_apply_noise = gr.Button("Apply Noise")
    # --- Wiring ---
    btn_random.click(init_random, [src, seqlen, seed], [ids_state, current_text, status])
    # Manual indices + prepend/append noise
    btn_apply_noise.click(
        apply_noise,
        [src, ids_state, seqlen, indices_csv, add_left, add_right, seed],
        [ids_state, current_text, status]
    )
    btn_step_once.click(
        step_once,
        [src, ids_state, update_mode, temperature, exclude_current],
        [ids_state, current_text, status]
    )
    # live_denoise is a generator callback: it streams intermediate snapshots
    btn_live.click(
        live_denoise,
        [src, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
        [ids_state, current_text, status],
        show_progress=True
    )
demo.queue().launch()