# app.py
import os
import re
import math
import random
import json
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from safetensors.torch import load_file as load_sft
from huggingface_hub import snapshot_download

torch.set_default_dtype(torch.float32)

# ===============================================
# Default config (from your training notes)
# ===============================================
DEFAULT_CONF = {
    "embed_dim": 1024,
    "num_heads": 8,
    "expansion_factor": 4,
    "num_blocks": 8,
    "radius": 16,
    "tokenizer_name": "gpt2",
}


# ===============================================
# Minimal CNA (inference-ready)
# ===============================================
class AttnBlock(nn.Module):
    """Pre-LayerNorm transformer block with banded (local) self-attention and RoPE.

    Each position may only attend to positions within ``radius`` of itself
    (in both directions), followed by a GELU MLP. Both sub-layers are residual.
    """

    def __init__(self, embed_dim, num_heads, expansion_factor):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.norm1 = nn.LayerNorm(embed_dim)
        self.QKV = nn.Linear(embed_dim, embed_dim * 3)
        self.Wo = nn.Linear(embed_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * expansion_factor),
            nn.GELU(),
            nn.Linear(embed_dim * expansion_factor, embed_dim),
        )
        # zero-init residual branches (match training)
        nn.init.zeros_(self.Wo.weight)
        nn.init.zeros_(self.Wo.bias)
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)

    def rope(self, Qh, Kh_seq, cos, sin):
        """Rotate interleaved (even, odd) channel pairs of Q and K by the RoPE angles.

        ``cos``/``sin`` hold each angle duplicated into adjacent channels (see
        ``CNA._rope_seq``), so taking the even channels recovers one angle per pair.
        """
        Qe = Qh[..., 0::2]
        Qo = Qh[..., 1::2]
        ce = cos[..., 0::2]
        se = sin[..., 0::2]
        Qr_e = Qe * ce - Qo * se
        Qr_o = Qe * se + Qo * ce
        Qh2 = torch.empty_like(Qh)
        Qh2[..., 0::2] = Qr_e
        Qh2[..., 1::2] = Qr_o

        Ke = Kh_seq[..., 0::2]
        Ko = Kh_seq[..., 1::2]
        Kr_e = Ke * ce - Ko * se
        Kr_o = Ke * se + Ko * ce
        Kh2 = torch.empty_like(Kh_seq)
        Kh2[..., 0::2] = Kr_e
        Kh2[..., 1::2] = Kr_o
        return Qh2, Kh2

    def forward(self, x, rope, radius):
        """Apply one banded-attention + MLP residual step.

        Args:
            x: activations of shape [B, S, E].
            rope: ``(cos, sin)`` tables of shape [1, S, 1, head_dim].
            radius: positions i, j with ``|i - j| <= radius`` may attend to each other.

        Returns:
            Tensor of shape [B, S, E].
        """
        # keep LN inputs & params same dtype
        if x.dtype != self.norm1.weight.dtype:
            x = x.to(self.norm1.weight.dtype)
        h = self.norm1(x)
        B, S, E = h.shape
        cos, sin = rope
        nh, hd = self.num_heads, self.head_dim
        cos = cos.to(h.dtype).to(h.device).permute(0, 2, 1, 3)  # [1,1,S,hd]
        sin = sin.to(h.dtype).to(h.device).permute(0, 2, 1, 3)

        # local band mask: 0 inside the band, the most negative finite value outside
        idx = torch.arange(S, device=h.device)
        idx_dist = (idx.view(1, S) - idx.view(S, 1)).abs()
        neg_inf = torch.finfo(h.dtype).min
        mask = torch.full((S, S), neg_inf, dtype=h.dtype, device=h.device)
        mask[idx_dist <= int(radius)] = 0
        mask = mask.view(1, 1, S, S)

        qkv = self.QKV(h)
        q, k, v = qkv.chunk(3, dim=-1)
        Qh = q.view(B, S, nh, hd).permute(0, 2, 1, 3).contiguous()
        Kh_seq = k.view(B, S, nh, hd).permute(0, 2, 1, 3).contiguous()
        Vh = v.view(B, S, nh, hd).permute(0, 2, 1, 3).contiguous()
        assert hd % 2 == 0, "rope needs even head_dim"
        Qh, Kh_seq = self.rope(Qh, Kh_seq, cos, sin)
        Kh = Kh_seq.permute(0, 1, 3, 2).contiguous()
        logits = (Qh @ Kh) * (hd ** -0.5)
        attn = F.softmax(logits + mask, dim=-1) @ Vh
        attn = attn.permute(0, 2, 1, 3).contiguous().view(B, S, E)
        x = x + self.Wo(attn)
        x = x + self.mlp(self.norm2(x))
        return x


class CNA(nn.Module):
    """Token embedding -> stack of ``AttnBlock`` -> linear projection to vocab logits."""

    def __init__(self, embed_dim, num_heads, expansion_factor, num_blocks, radius, vocab_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.expansion_factor = expansion_factor
        self.num_blocks = num_blocks
        self.vocab_size = vocab_size
        self.radius = radius
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList(
            [AttnBlock(embed_dim, num_heads, expansion_factor) for _ in range(num_blocks)]
        )
        self.proj = nn.Linear(embed_dim, vocab_size)

    def _rope_seq(self, S, hd, device, dtype, base=10000.0):
        """Build RoPE cos/sin tables of shape [1, S, 1, hd].

        Each frequency is duplicated into two adjacent channels so it lines up
        with the interleaved (even, odd) pairs rotated in ``AttnBlock.rope``.
        """
        pos = torch.arange(S, device=device, dtype=dtype)
        half = hd // 2
        idx = torch.arange(half, device=device, dtype=dtype)
        inv = base ** (-idx / half)
        ang = pos[:, None] * inv[None, :]
        cos = ang.cos().unsqueeze(0).unsqueeze(2)
        sin = ang.sin().unsqueeze(0).unsqueeze(2)
        cos = torch.stack((cos, cos), dim=-1).reshape(1, S, 1, hd)
        sin = torch.stack((sin, sin), dim=-1).reshape(1, S, 1, hd)
        return cos, sin

    def forward(self, x):
        """Return logits [B, S, vocab] for token ids [B, S] (or pre-embedded [B, S, E])."""
        if x.dtype == torch.long and x.dim() == 2:
            h = self.tok_emb(x)
        else:
            h = x
        # ensure embeddings/activations dtype follows model dtype
        target_dtype = next(self.parameters()).dtype
        if h.dtype != target_dtype:
            h = h.to(target_dtype)
        B, S, E = h.shape
        hd = self.embed_dim // self.num_heads
        cos, sin = self._rope_seq(S, hd, h.device, h.dtype)
        for blk in self.blocks:
            h = blk(h, rope=(cos, sin), radius=self.radius)
        return self.proj(h)


# ===============================================
# Helpers
# ===============================================
def to_batch2(ids_like) -> torch.Tensor:
    """
    Normalize ids_like (list, [[...]], tensor) to int64 shape [1, S].
    Accepts [S], [1,S], [1,1,S]; returns [1,S].
    """
    x = torch.tensor(ids_like, dtype=torch.long)
    if x.dim() == 1:
        x = x.unsqueeze(0)  # [S] -> [1,S]
    elif x.dim() == 3 and x.shape[0] == 1 and x.shape[1] == 1:
        x = x.squeeze(1)  # [1,1,S] -> [1,S]
    elif x.dim() != 2:
        x = x.view(1, -1)  # fallback reshape
    return x


def infer_expansion_factor_from_state(state, embed_dim):
    """Infer the MLP expansion factor from block 0's MLP weight shapes.

    Falls back to ``DEFAULT_CONF["expansion_factor"]`` when neither weight is present.
    """
    for key in ("blocks.0.mlp.0.weight", "blocks.0.mlp.2.weight"):
        if key in state:
            W = state[key]
            if key.endswith("0.weight"):
                # first Linear: [embed_dim * factor, embed_dim]
                return int(W.shape[0] // embed_dim)
            else:
                # last Linear: [embed_dim, embed_dim * factor]
                return int(W.shape[1] // embed_dim)
    return DEFAULT_CONF["expansion_factor"]


@torch.no_grad()
def decode(ids, tokenizer, max_chars=1000):
    """Decode ids to a single-line string, truncated to ``max_chars`` (with an ellipsis)."""
    s = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
    s = s.replace("\n", " ")
    return s[:max_chars] + ("…" if len(s) > max_chars else "")


@torch.no_grad()
def model_logits(model, x):
    """Run the model without autograd and return its logits."""
    return model(x)


def to_fixed_len_ids(text, tokenizer, seqlen, pad_mode="random", rnd=None):
    """Encode ``text`` and truncate/pad to exactly ``seqlen`` ids, returned as [1, seqlen].

    Padding uses EOS when ``pad_mode == "eos"`` and the tokenizer has one,
    otherwise random token ids drawn from ``rnd``.
    """
    if rnd is None:
        rnd = random.Random()
    ids = tokenizer.encode(text, add_special_tokens=False)
    V = tokenizer.vocab_size
    if len(ids) >= seqlen:
        ids = ids[:seqlen]
    else:
        need = seqlen - len(ids)
        if pad_mode == "eos" and tokenizer.eos_token_id is not None:
            ids = ids + [tokenizer.eos_token_id] * need
        else:
            ids = ids + [rnd.randrange(V) for _ in range(need)]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)


def _parse_indices_csv(indices_csv):
    """Parse a spec like ``"0, 5, 10-20"`` into a set of ints; malformed entries are skipped."""
    idxs = set()
    if not indices_csv or not indices_csv.strip():
        return idxs
    for part in indices_csv.split(","):
        part = part.strip()
        if not part:
            continue
        try:
            if "-" in part:
                a, b = part.split("-", 1)
                a, b = int(a), int(b)
                for j in range(min(a, b), max(a, b) + 1):
                    idxs.add(j)
            else:
                idxs.add(int(part))
        except ValueError:
            # best-effort: ignore entries that are not integers or int ranges
            continue
    return idxs


def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right, seqlen, seed=0):
    """Return a noised copy of ``x`` ([1, S] ids), resized to exactly ``seqlen``.

    Steps, all drawing from ``random.Random(seed)``:
    1. Replace the positions listed in ``indices_csv`` (in range) with random ids.
    2. Prepend ``add_noise_left`` and append ``add_noise_right`` random ids.
    3. Truncate from the right, or pad with random ids, to ``seqlen``.
    """
    rnd = random.Random(seed)
    V = tokenizer.vocab_size
    x = x.clone()
    for j in _parse_indices_csv(indices_csv):
        if 0 <= j < x.shape[1]:
            x[0, j] = rnd.randrange(V)
    if add_noise_left > 0:
        prefix = torch.tensor(
            [rnd.randrange(V) for _ in range(int(add_noise_left))], dtype=torch.long
        ).unsqueeze(0)
        x = torch.cat([prefix, x], dim=1)
    if add_noise_right > 0:
        suffix = torch.tensor(
            [rnd.randrange(V) for _ in range(int(add_noise_right))], dtype=torch.long
        ).unsqueeze(0)
        x = torch.cat([x, suffix], dim=1)
    if x.shape[1] > seqlen:
        x = x[:, :seqlen]
    elif x.shape[1] < seqlen:
        need = seqlen - x.shape[1]
        pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, pad], dim=1)
    return x


@torch.no_grad()
def sample_from_logits(logits_row, temperature=1.0, current_token=None, exclude_current=True):
    """Pick a token id from one row of logits.

    ``temperature <= 0`` means argmax. Otherwise sample from
    ``softmax(logits / temperature)``, optionally zeroing ``current_token``. If that
    leaves no probability mass, fall back to argmax.
    """
    if temperature <= 0:
        return int(torch.argmax(logits_row).item())
    scaled = logits_row / float(temperature)
    probs = torch.softmax(scaled, dim=-1)
    if exclude_current and current_token is not None:
        probs = probs.clone()
        probs[current_token] = 0.0
        s = probs.sum()
        if s.item() <= 0:
            return int(torch.argmax(logits_row).item())
        probs = probs / s
    return int(torch.multinomial(probs, 1).item())


# ===============================================
# Weight loading (file / folder / HF Hub)
# ===============================================
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
DEFAULT_WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "weights_latest")


def _read_config_from_dict_or_infer(state, cfg):
    """Merge DEFAULT_CONF with ``cfg``, overriding shape fields inferable from ``state``.

    ``embed_dim``, ``num_blocks`` and ``expansion_factor`` are taken from tensor
    shapes/keys when present. ``num_heads`` and ``radius`` cannot be inferred and
    come from ``cfg`` or the defaults.
    """
    merged = {**DEFAULT_CONF, **(cfg or {})}
    if "tok_emb.weight" in state:
        merged["embed_dim"] = state["tok_emb.weight"].shape[1]
    block_idxs = [
        int(m.group(1)) for k in state.keys() for m in [re.match(r"blocks\.(\d+)\.", k)] if m
    ]
    if block_idxs:
        merged["num_blocks"] = max(block_idxs) + 1
    if "blocks.0.mlp.0.weight" in state or "blocks.0.mlp.2.weight" in state:
        merged["expansion_factor"] = infer_expansion_factor_from_state(state, merged["embed_dim"])
    if not merged.get("tokenizer_name"):
        merged["tokenizer_name"] = "gpt2"
    return merged


def _is_state_dict(obj):
    """Heuristic: a non-empty dict whose first value is a tensor."""
    if isinstance(obj, dict) and obj:
        sample_val = next(iter(obj.values()))
        return isinstance(sample_val, torch.Tensor)
    return False


def _read_json_config(path: str) -> dict:
    """Return the parsed JSON object at ``path``, or ``{}`` if the file is absent or null."""
    if not os.path.exists(path):
        return {}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f) or {}


def _load_state_from_pt(path: str):
    """Load a ``.pt``/``.bin`` file holding a raw state_dict or a ``{"model": ...}`` payload.

    Returns ``(state, cfg)``. Raises ValueError for any other layout.
    """
    # NOTE(review): torch.load unpickles; only load checkpoints from trusted sources.
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict) and "model" in obj and isinstance(obj["model"], dict):
        state = obj["model"]
        cfg = obj.get("config", {}) or {}
        if "tokenizer_name" in obj:
            cfg = {**cfg, "tokenizer_name": obj["tokenizer_name"]}
        return state, cfg
    if _is_state_dict(obj):
        return obj, {}
    raise ValueError(
        f"Unsupported .pt format at {path}: expected a state_dict or a payload with 'model'."
    )


def _load_state_from_safetensors_file(path: str):
    """Load a single ``.safetensors`` file plus a sibling ``config.json`` if one exists."""
    state = load_sft(path)
    cfg = _read_json_config(os.path.join(os.path.dirname(os.path.abspath(path)), "config.json"))
    return state, cfg


def _merge_state_dicts(dicts):
    """Shallow-merge shards; later shards win on duplicate keys."""
    merged = {}
    for d in dicts:
        merged.update(d)
    return merged


def _load_state_from_folder(weights_dir: str):
    """Load weights (and optional ``config.json``) from a folder.

    Preference order: ``model.safetensors``, then every ``*.safetensors`` shard
    (sorted), then every ``*.pt``/``*.bin`` shard (sorted). Returns ``(state, cfg)``.
    """
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"Folder not found: {weights_dir}")
    cfg = _read_json_config(os.path.join(weights_dir, "config.json"))

    files = sorted(os.listdir(weights_dir))
    sft_files = [f for f in files if f.endswith(".safetensors")]
    pt_files = [f for f in files if f.endswith((".pt", ".bin"))]

    if "model.safetensors" in sft_files:
        state = load_sft(os.path.join(weights_dir, "model.safetensors"))
    elif sft_files:
        parts = [load_sft(os.path.join(weights_dir, f)) for f in sft_files]
        state = _merge_state_dicts(parts)
    elif pt_files:
        parts = []
        for f in pt_files:
            # NOTE(review): torch.load unpickles; only load shards from trusted sources.
            part = torch.load(os.path.join(weights_dir, f), map_location="cpu")
            if isinstance(part, dict) and "model" in part and isinstance(part["model"], dict):
                parts.append(part["model"])
                if "config" in part and isinstance(part["config"], dict):
                    cfg = {**cfg, **part["config"]}
                if "tokenizer_name" in part:
                    cfg.setdefault("tokenizer_name", part["tokenizer_name"])
            elif _is_state_dict(part):
                parts.append(part)
            else:
                raise ValueError(f"Unsupported shard format: {f}")
        state = _merge_state_dicts(parts)
    else:
        raise FileNotFoundError(
            f"No weights found in {weights_dir}. Expected .safetensors or .pt files."
        )
    return state, cfg


def _load_state_from_hub(repo_id: str, subfolder: str | None = None, revision: str | None = None):
    """Download a Hub repo snapshot and load weights from it (optionally from a subfolder)."""
    cache_dir = snapshot_download(repo_id=repo_id, revision=revision, allow_patterns=None)
    path = os.path.join(cache_dir, subfolder) if subfolder else cache_dir
    return _load_state_from_folder(path)


def load_model(source: str):
    """Resolve ``source`` to weights, build the tokenizer and CNA model, and load them.

    ``source`` may be:
    - a ``.pt``/``.bin`` file;
    - a ``.safetensors`` file;
    - a folder;
    - a Hub repo id (anything containing ``/`` that is not a local path).
    Otherwise the local defaults are tried.

    Returns ``(model, tokenizer, radius)``; the model is float32 and in eval mode.
    """
    src = source or ""
    state, cfg = None, {}
    if os.path.isfile(src) and src.endswith((".pt", ".bin")):
        state, cfg = _load_state_from_pt(src)
    elif os.path.isfile(src) and src.endswith(".safetensors"):
        # _auto_default_source may return a lone .safetensors file; load it directly.
        state, cfg = _load_state_from_safetensors_file(src)
    elif os.path.isdir(src):
        state, cfg = _load_state_from_folder(src)
    elif "/" in src:
        # Hub repo id
        subfolder = os.environ.get("WEIGHTS_SUBFOLDER") or None
        revision = os.environ.get("WEIGHTS_REVISION") or None
        state, cfg = _load_state_from_hub(src, subfolder=subfolder, revision=revision)
    else:
        # fallbacks
        if os.path.isfile("weights_latest.pt"):
            state, cfg = _load_state_from_pt("weights_latest.pt")
        elif os.path.isfile(DEFAULT_CKPT):
            state, cfg = _load_state_from_pt(DEFAULT_CKPT)
        elif os.path.isdir(DEFAULT_WEIGHTS_DIR):
            state, cfg = _load_state_from_folder(DEFAULT_WEIGHTS_DIR)
        else:
            raise FileNotFoundError(
                f"Could not resolve weights from '{src}'. Tried file (.pt/.bin/.safetensors), "
                f"folder, hub repo id, "
                f"then defaults ('{DEFAULT_CKPT}', '{DEFAULT_WEIGHTS_DIR}')."
            )

    conf = _read_config_from_dict_or_infer(state, cfg)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(conf["tokenizer_name"], use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 1_000_000_000
    vocab_size = tokenizer.vocab_size

    # Build model
    model = CNA(
        conf["embed_dim"],
        conf["num_heads"],
        conf["expansion_factor"],
        conf["num_blocks"],
        conf["radius"],
        vocab_size,
    )

    # Load state, tolerating a missing or differently-sized output projection.
    # load_state_dict raises on shape mismatches even with strict=False, so
    # mismatched proj.* tensors are dropped first and then re-initialized below.
    model_state = model.state_dict()
    mismatched = {
        k
        for k, v in state.items()
        if k.startswith("proj.") and k in model_state and v.shape != model_state[k].shape
    }
    if mismatched:
        state = {k: v for k, v in state.items() if k not in mismatched}
    missing, unexpected = model.load_state_dict(state, strict=False)
    if any(k.startswith("proj.") for k in missing):
        warnings.warn(
            "Output projection (proj.*) missing or shape-mismatched in checkpoint; "
            "re-initializing it randomly."
        )
        with torch.no_grad():
            nn.init.normal_(model.proj.weight, std=0.02)
            nn.init.zeros_(model.proj.bias)
    else:
        model.load_state_dict(state, strict=True)

    # enforce float32 across params & buffers (.to(dtype) only casts floating-point tensors)
    model = model.to(torch.float32)
    model.eval()
    return model, tokenizer, conf["radius"]


model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}


def _auto_default_source():
    """Pick a weights source.

    Order: the ``WEIGHTS_SOURCE`` env var, a ``weights_latest`` folder, known
    checkpoint names, then the first ``*.pt``/``*.safetensors`` in the CWD.
    """
    env = os.environ.get("WEIGHTS_SOURCE")
    if env:
        return env
    if os.path.isdir("weights_latest"):
        return "weights_latest"
    for name in ["weights_latest.pt", "ckpt_latest.pt"]:
        if os.path.isfile(name):
            return name
    for f in sorted(os.listdir(".")):
        if f.endswith(".pt") or f.endswith(".safetensors"):
            return f
    return "weights_latest.pt"


def ensure_model(source_path_or_repo):
    """Load the model into ``model_cache`` unless the same source is already cached."""
    src = source_path_or_repo or _auto_default_source()
    if model_cache["model"] is None or model_cache["ckpt"] != src:
        m, tok, rad = load_model(src)
        model_cache.update({"model": m, "tokenizer": tok, "radius": rad, "ckpt": src})


# ===============================================
# Strategy 1 (random position) with argmax / sample
# ===============================================
@torch.no_grad()
def step_strategy1(model, x, mode="argmax", temperature=1.0, exclude_current=True):
    """Resample one uniformly random position of ``x`` in place from the model's logits.

    ``mode="sample"`` uses ``sample_from_logits``; any other mode uses argmax.
    Returns ``x``, which has been mutated.
    """
    S = x.shape[1]
    pos = int(torch.randint(0, S, (1,)).item())
    logits_pos = model_logits(model, x)[0, pos]
    if mode == "sample":
        cur_tok = int(x[0, pos].item())
        new_tok = sample_from_logits(
            logits_pos,
            temperature=float(temperature),
            current_token=cur_tok,
            exclude_current=bool(exclude_current),
        )
        x[0, pos] = new_tok
    else:
        x[0, pos] = int(torch.argmax(logits_pos).item())
    return x


# ===============================================
# Gradio callbacks
# ===============================================
def init_random(src, seqlen, seed):
    """Create a seeded uniformly random id sequence of length ``seqlen``."""
    ensure_model(src)
    random.seed(seed)
    torch.manual_seed(seed)
    V = model_cache["tokenizer"].vocab_size
    x = torch.randint(0, V, (1, int(seqlen)))
    txt = decode(x[0], model_cache["tokenizer"])
    return x.tolist(), txt, f"Initialized random sequence (len={int(seqlen)})"


def to_ranges(indices):
    """Compress a sorted list of token indices into 'a-b' CSV."""
    if not indices:
        return ""
    indices = sorted(set(indices))
    ranges = []
    start = prev = indices[0]
    for i in indices[1:]:
        if i == prev + 1:
            prev = i
        else:
            ranges.append((start, prev))
            start = prev = i
    ranges.append((start, prev))
    parts = [f"{a}-{b}" if a != b else f"{a}" for a, b in ranges]
    return ", ".join(parts)


def capture_selection(text, seqlen, current_ids, evt: gr.SelectData | None = None):
    """
    Map highlighted character span in `text` to token index ranges using tokenizer offsets.
    Auto-fills the indices box so you can 'Noise Selection'.
    """
    # NOTE(review): this re-tokenizes `text`, which is the decoded/newline-flattened
    # display string, so offsets may not line up exactly with `current_ids`.
    ensure_model(None)
    tok = model_cache["tokenizer"]
    if not text:
        return gr.update(), "No text to select from."

    # Try to read (start, end) from the event payload
    start, end = None, None
    if evt is not None:
        try:
            # gradio SelectData for Textbox exposes .index = (start_char, end_char)
            start, end = evt.index
        except Exception:
            pass

    # Fallback: nothing selected
    if start is None or end is None or start == end:
        return gr.update(), "No selection detected (drag to highlight)."

    # Bound the indices defensively
    start = max(0, min(len(text), int(start)))
    end = max(0, min(len(text), int(end)))

    # Get per-token char offsets from the fast tokenizer
    enc = tok(text, add_special_tokens=False, return_offsets_mapping=True)
    offsets = enc["offset_mapping"]  # list of (s,e) per token

    token_idxs = []
    for i, (s, e) in enumerate(offsets):
        if s is None or e is None:
            continue
        # overlap if token span intersects [start, end)
        if max(s, start) < min(e, end):
            token_idxs.append(i)

    if not token_idxs:
        return gr.update(), "Selection didn't hit any tokens (maybe whitespace)."

    # Clip to current sequence length (so we don't index beyond S)
    S = int(seqlen)
    token_idxs = [i for i in token_idxs if i < S]
    if not token_idxs:
        return gr.update(), "Selected span maps beyond current sequence length."

    indices_csv = to_ranges(token_idxs)
    return indices_csv, f"Selected chars [{start}:{end}) → tokens {indices_csv}"


def noise_selection(src, state_ids, seqlen, indices_csv, seed):
    """Noise only the listed positions (no tokens added at either end)."""
    # Reuse apply_noise but force prepend/append noise to zero
    return apply_noise(src, state_ids, seqlen, indices_csv, 0, 0, seed)


def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
    """UI wrapper around ``apply_noise_ops``.

    Starts from a fresh random sequence when there is no current one.
    """
    ensure_model(src)
    tok = model_cache["tokenizer"]
    S = int(seqlen)
    if state_ids is None or len(state_ids) == 0:
        V = tok.vocab_size
        base = torch.randint(0, V, (1, S))
    else:
        base = to_batch2(state_ids)
    x = apply_noise_ops(
        base, tok, indices_csv, int(add_left or 0), int(add_right or 0), S, seed=seed
    )
    txt = decode(x[0], tok)
    return x.tolist(), txt, "Applied noise"


def step_once(src, state_ids, mode, temperature, exclude_current):
    """Apply a single ``step_strategy1`` update to the current sequence."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        return None, "", "No sequence to step — initialize first."
    x = to_batch2(state_ids)
    x = step_strategy1(
        model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current
    )
    txt = decode(x[0], tok)
    return x.tolist(), txt, f"Stepped 1 iteration ({mode})"


def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exclude_current):
    """Generator: run ``steps`` seeded updates, yielding a snapshot every ``snap_every`` steps.

    The final step is always yielded.
    """
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        # Report the problem instead of silently yielding nothing (mirrors step_once).
        yield None, "", "No sequence to denoise — initialize first."
        return
    random.seed(seed)
    torch.manual_seed(seed)
    x = to_batch2(state_ids)
    total = int(steps)
    snap = max(1, int(snap_every))
    for t in range(1, total + 1):
        x = step_strategy1(
            model_cache["model"],
            x,
            mode=mode,
            temperature=temperature,
            exclude_current=exclude_current,
        )
        if (t % snap == 0) or (t == total):
            txt = decode(x[0], tok)
            yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"


# ===============================================
# UI (single mode)
# ===============================================
with gr.Blocks(title="Self Organising Text Demo") as demo:
    gr.Markdown(
        """
# Self Organising Text Demo
Watch text self organise using only local attention.
"""
    )

    default_source = os.environ.get("WEIGHTS_SOURCE", None)
    if default_source is None:
        default_source = _auto_default_source()

    with gr.Row():
        src = gr.Textbox(value=default_source, label="Weights (file / folder / HF repo id)")
        seqlen = gr.Slider(10, 512, value=50, step=1, label="Sequence length (S)")
        seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")

    ids_state = gr.State(value=None)

    with gr.Row():
        current_text = gr.Textbox(lines=8, label="Current text", interactive=True)
    status = gr.Markdown("Ready.")

    gr.Markdown("### Initialize & Denoise")
    with gr.Row():
        btn_random = gr.Button("Initialize Random")
        steps = gr.Slider(1, 2000, value=100, step=1, label="Denoise steps (N)")  # default 100
        snap_every = gr.Slider(1, 100, value=1, step=1, label="Update every K steps")  # default 1
    with gr.Row():
        update_mode = gr.Radio(
            choices=["argmax", "sample"],
            value="sample",  # default to sampling
            label="Update rule",
        )
        temperature = gr.Slider(
            minimum=0.0, maximum=5.0, value=1.0, step=0.05, label="Temperature (sampling)"
        )
        exclude_current = gr.Checkbox(value=True, label="Exclude current token when sampling")
    with gr.Row():
        btn_step_once = gr.Button("Step Once")
        btn_live = gr.Button("Denoise Live (streaming)")

    gr.Markdown("### Noise by Indices")
    with gr.Row():
        indices_csv = gr.Textbox(
            label="Positions to noise (enter like: 0, 5, 10-20)",
            placeholder="e.g., 0, 5, 10-20",
        )
    with gr.Row():
        add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
        add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
        btn_apply_noise = gr.Button("Apply Noise")

    # --- Wiring ---
    btn_random.click(init_random, [src, seqlen, seed], [ids_state, current_text, status])

    # Manual indices + prepend/append noise
    btn_apply_noise.click(
        apply_noise,
        [src, ids_state, seqlen, indices_csv, add_left, add_right, seed],
        [ids_state, current_text, status],
    )

    btn_step_once.click(
        step_once,
        [src, ids_state, update_mode, temperature, exclude_current],
        [ids_state, current_text, status],
    )

    btn_live.click(
        live_denoise,
        [src, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
        [ids_state, current_text, status],
        show_progress=True,
    )

demo.queue().launch()