# (Hugging Face Space page header removed — the lines "Spaces:" / "Sleeping"
# were scrape residue from the Space's status banner, not part of app.py.)
# app.py
import os, re, math, random, json
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
from safetensors.torch import load_file as load_sft
from huggingface_hub import snapshot_download

# Run all inference in float32 (load_model also coerces weights to float32).
torch.set_default_dtype(torch.float32)
# ===============================================
# Default config (from your training notes)
# ===============================================
# Fallback hyperparameters used when a checkpoint carries no config;
# embed_dim / num_blocks / expansion_factor are overridden by values
# inferred from the state_dict in _read_config_from_dict_or_infer.
DEFAULT_CONF = {
    "embed_dim": 1024,         # residual / token-embedding width
    "num_heads": 8,            # attention heads per block
    "expansion_factor": 4,     # MLP hidden width multiplier
    "num_blocks": 8,           # number of stacked AttnBlocks
    "radius": 16,              # local-attention band half-width (tokens)
    "tokenizer_name": "gpt2",  # HF tokenizer paired with the model
}
# ===============================================
# Minimal CNA (inference-ready)
# ===============================================
class AttnBlock(nn.Module):
    """Pre-norm transformer block: banded (local) self-attention with rotary
    position embeddings, followed by a GELU MLP. Both residual branches are
    zero-initialized so a freshly built block is an identity map (matching
    the training setup, per the comment below)."""
    def __init__(self, embed_dim, num_heads, expansion_factor):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.norm1 = nn.LayerNorm(embed_dim)
        # Fused projection producing Q, K and V in a single matmul.
        self.QKV = nn.Linear(embed_dim, embed_dim * 3)
        self.Wo = nn.Linear(embed_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * expansion_factor),
            nn.GELU(),
            nn.Linear(embed_dim * expansion_factor, embed_dim),
        )
        # zero-init residual branches (match training)
        nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
        nn.init.zeros_(self.mlp[-1].weight); nn.init.zeros_(self.mlp[-1].bias)
    def rope(self, Qh, Kh_seq, cos, sin):
        """Apply rotary position embeddings to Q and K.

        Even/odd feature pairs are rotated by the per-position angles in
        (cos, sin). The tables carry each base angle duplicated across the
        pair (see CNA._rope_seq), so slicing [0::2] recovers the base angle.
        """
        Qe = Qh[..., 0::2]; Qo = Qh[..., 1::2]
        ce = cos[..., 0::2]; se = sin[..., 0::2]
        Qr_e = Qe * ce - Qo * se
        Qr_o = Qe * se + Qo * ce
        Qh2 = torch.empty_like(Qh); Qh2[..., 0::2] = Qr_e; Qh2[..., 1::2] = Qr_o
        Ke = Kh_seq[..., 0::2]; Ko = Kh_seq[..., 1::2]
        Kr_e = Ke * ce - Ko * se
        Kr_o = Ke * se + Ko * ce
        Kh2 = torch.empty_like(Kh_seq); Kh2[..., 0::2] = Kr_e; Kh2[..., 1::2] = Kr_o
        return Qh2, Kh2
    def forward(self, x, rope, radius):
        """x: [B, S, E]; rope: (cos, sin) tables from CNA._rope_seq; radius:
        band half-width — position i may attend to j iff |i - j| <= radius
        (non-causal: the band is symmetric)."""
        # keep LN inputs & params same dtype
        if x.dtype != self.norm1.weight.dtype:
            x = x.to(self.norm1.weight.dtype)
        h = self.norm1(x)
        B, S, E = h.shape
        cos, sin = rope
        nh, hd = self.num_heads, self.head_dim
        cos = cos.to(h.dtype).to(h.device).permute(0, 2, 1, 3)  # [1,1,S,hd]
        sin = sin.to(h.dtype).to(h.device).permute(0, 2, 1, 3)
        # local band mask: 0 inside the band, dtype-min (≈ -inf) outside
        idx = torch.arange(S, device=h.device)
        idx_dist = (idx.view(1, S) - idx.view(S, 1)).abs()
        neg_inf = torch.finfo(h.dtype).min
        mask = torch.full((S, S), neg_inf, dtype=h.dtype, device=h.device)
        mask[idx_dist <= int(radius)] = 0
        mask = mask.view(1, 1, S, S)
        qkv = self.QKV(h)
        q, k, v = qkv.chunk(3, dim=-1)
        # [B,S,E] -> [B,nh,S,hd]
        Qh = q.view(B, S, nh, hd).permute(0, 2, 1, 3).contiguous()
        Kh_seq = k.view(B, S, nh, hd).permute(0, 2, 1, 3).contiguous()
        Vh = v.view(B, S, nh, hd).permute(0, 2, 1, 3).contiguous()
        assert hd % 2 == 0, "rope needs even head_dim"
        Qh, Kh_seq = self.rope(Qh, Kh_seq, cos, sin)
        Kh = Kh_seq.permute(0, 1, 3, 2).contiguous()  # transpose for Q @ K^T
        logits = (Qh @ Kh) * (hd ** -0.5)  # scaled dot-product scores
        attn = F.softmax(logits + mask, dim=-1) @ Vh
        attn = attn.permute(0, 2, 1, 3).contiguous().view(B, S, E)
        x = x + self.Wo(attn)            # attention residual
        x = x + self.mlp(self.norm2(x))  # MLP residual
        return x
class CNA(nn.Module):
    """Non-causal denoising text model: token embedding, a stack of
    local-attention AttnBlocks sharing one RoPE table, and a linear
    projection back to vocabulary logits."""
    def __init__(self, embed_dim, num_heads, expansion_factor, num_blocks, radius, vocab_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.expansion_factor = expansion_factor
        self.num_blocks = num_blocks
        self.vocab_size = vocab_size
        self.radius = radius
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([AttnBlock(embed_dim, num_heads, expansion_factor) for _ in range(num_blocks)])
        self.proj = nn.Linear(embed_dim, vocab_size)
    def _rope_seq(self, S, hd, device, dtype, base=10000.0):
        """Build RoPE tables for sequence length S and head dim hd.

        Returns (cos, sin), each of shape [1, S, 1, hd], with every base
        angle duplicated across adjacent feature pairs so AttnBlock.rope can
        slice them back out with [..., 0::2].
        """
        pos = torch.arange(S, device=device, dtype=dtype)
        half = hd // 2
        idx = torch.arange(half, device=device, dtype=dtype)
        inv = base ** (-idx / half)         # inverse frequencies per pair
        ang = pos[:, None] * inv[None, :]   # [S, half] angle table
        cos = ang.cos().unsqueeze(0).unsqueeze(2)
        sin = ang.sin().unsqueeze(0).unsqueeze(2)
        # interleave duplicates: [1,S,1,half] -> [1,S,1,hd]
        cos = torch.stack((cos, cos), dim=-1).reshape(1, S, 1, hd)
        sin = torch.stack((sin, sin), dim=-1).reshape(1, S, 1, hd)
        return cos, sin
    def forward(self, x):
        """x: int64 token ids [B, S], or anything else treated as
        pre-embedded activations [B, S, E]. Returns logits [B, S, vocab]."""
        if x.dtype == torch.long and x.dim() == 2:
            h = self.tok_emb(x)
        else:
            h = x
        # ensure embeddings/activations dtype follows model dtype
        target_dtype = next(self.parameters()).dtype
        if h.dtype != target_dtype:
            h = h.to(target_dtype)
        B, S, E = h.shape
        hd = self.embed_dim // self.num_heads
        cos, sin = self._rope_seq(S, hd, h.device, h.dtype)
        for blk in self.blocks:
            h = blk(h, rope=(cos, sin), radius=self.radius)
        return self.proj(h)
# ===============================================
# Helpers
# ===============================================
def to_batch2(ids_like) -> torch.Tensor:
    """
    Normalize ids_like (list, [[...]], tensor) to int64 shape [1, S].
    Accepts [S], [1,S], [1,1,S]; returns [1,S].
    """
    tensor = torch.tensor(ids_like, dtype=torch.long)
    if tensor.dim() == 2:
        return tensor
    if tensor.dim() == 1:
        return tensor.unsqueeze(0)          # [S] -> [1,S]
    if tensor.dim() == 3 and tensor.shape[0] == 1 and tensor.shape[1] == 1:
        return tensor.squeeze(1)            # [1,1,S] -> [1,S]
    return tensor.view(1, -1)               # anything else: flatten to one row
def infer_expansion_factor_from_state(state, embed_dim):
    """Recover the MLP expansion factor from checkpoint weight shapes.

    Checks the first block's up-projection (rows / embed_dim) first, then its
    down-projection (cols / embed_dim); falls back to DEFAULT_CONF otherwise.
    """
    up_proj = state.get("blocks.0.mlp.0.weight")
    if up_proj is not None:
        return int(up_proj.shape[0] // embed_dim)
    down_proj = state.get("blocks.0.mlp.2.weight")
    if down_proj is not None:
        return int(down_proj.shape[1] // embed_dim)
    return DEFAULT_CONF["expansion_factor"]
| def decode(ids, tokenizer, max_chars=1000): | |
| s = tokenizer.decode(ids.tolist(), skip_special_tokens=True) | |
| s = s.replace("\n", " ") | |
| return s[:max_chars] + ("…" if len(s) > max_chars else "") | |
def model_logits(model, x):
    """Single call site for model evaluation (easy to patch or instrument)."""
    out = model(x)
    return out
def to_fixed_len_ids(text, tokenizer, seqlen, pad_mode="random", rnd=None):
    """Encode `text` and force it to exactly `seqlen` tokens.

    Longer inputs are truncated; shorter ones are right-padded with EOS
    (pad_mode="eos", when the tokenizer defines one) or uniform-random ids.
    Returns an int64 tensor of shape [1, seqlen].
    """
    rng = rnd if rnd is not None else random.Random()
    ids = tokenizer.encode(text, add_special_tokens=False)[:seqlen]
    vocab = tokenizer.vocab_size
    shortfall = seqlen - len(ids)
    if shortfall > 0:
        if pad_mode == "eos" and tokenizer.eos_token_id is not None:
            ids = ids + [tokenizer.eos_token_id] * shortfall
        else:
            ids = ids + [rng.randrange(vocab) for _ in range(shortfall)]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)
def apply_noise_ops(x, tokenizer, indices_csv, add_noise_left, add_noise_right, seqlen, seed=0):
    """Corrupt a token sequence in controlled ways; returns a new [1, seqlen] tensor.

    - indices_csv: comma-separated positions / inclusive ranges ("0, 5, 10-20")
      whose tokens are replaced by uniform-random ids; malformed entries are
      skipped silently (same best-effort behavior as before, but now only
      ValueError is swallowed instead of a bare except).
    - add_noise_left / add_noise_right: counts of random tokens to prepend/append.
    - The result is then truncated or random-padded back to exactly `seqlen`.

    The input tensor is cloned, never mutated. All randomness comes from
    random.Random(seed), so output is reproducible per seed.
    """
    rnd = random.Random(seed)
    V = tokenizer.vocab_size
    x = x.clone()
    # Parse the CSV of positions / inclusive ranges into a set of indices.
    idxs = set()
    if indices_csv and indices_csv.strip():
        for part in indices_csv.split(","):
            part = part.strip()
            if not part:
                continue
            if "-" in part:
                a, b = part.split("-", 1)
                try:
                    a, b = int(a), int(b)
                except ValueError:
                    continue  # malformed range like "3-x": ignore
                for j in range(min(a, b), max(a, b) + 1):
                    idxs.add(j)
            else:
                try:
                    idxs.add(int(part))
                except ValueError:
                    pass  # malformed single index: ignore
    # Randomize the requested in-bounds positions.
    for j in idxs:
        if 0 <= j < x.shape[1]:
            x[0, j] = rnd.randrange(V)
    if add_noise_left > 0:
        prefix = torch.tensor([rnd.randrange(V) for _ in range(int(add_noise_left))], dtype=torch.long).unsqueeze(0)
        x = torch.cat([prefix, x], dim=1)
    if add_noise_right > 0:
        suffix = torch.tensor([rnd.randrange(V) for _ in range(int(add_noise_right))], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, suffix], dim=1)
    # Force the final length back to seqlen (truncate, or random-pad on the right).
    if x.shape[1] > seqlen:
        x = x[:, :seqlen]
    elif x.shape[1] < seqlen:
        need = seqlen - x.shape[1]
        pad = torch.tensor([rnd.randrange(V) for _ in range(need)], dtype=torch.long).unsqueeze(0)
        x = torch.cat([x, pad], dim=1)
    return x
def sample_from_logits(logits_row, temperature=1.0, current_token=None, exclude_current=True):
    """Draw a token id from a 1-D logits vector.

    temperature <= 0 falls back to greedy argmax. When exclude_current is set
    and current_token is given, that token's probability is zeroed before
    sampling; if nothing remains to sample, argmax is used as a fallback.
    """
    if temperature <= 0:
        return int(torch.argmax(logits_row).item())
    probs = torch.softmax(logits_row / float(temperature), dim=-1)
    if exclude_current and current_token is not None:
        probs = probs.clone()
        probs[current_token] = 0.0
        total = probs.sum()
        if total.item() <= 0:
            return int(torch.argmax(logits_row).item())
        probs = probs / total
    return int(torch.multinomial(probs, 1).item())
# ===============================================
# Weight loading (file / folder / HF Hub)
# ===============================================
# Env-overridable default locations used by load_model's fallback chain.
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")
DEFAULT_WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", "weights_latest")
def _read_config_from_dict_or_infer(state, cfg):
    """Merge DEFAULT_CONF with `cfg`, then override with values that can be
    inferred from the state_dict itself (embed_dim, num_blocks,
    expansion_factor); guarantees a tokenizer_name."""
    merged = {**DEFAULT_CONF, **(cfg or {})}
    tok_emb = state.get("tok_emb.weight")
    if tok_emb is not None:
        merged["embed_dim"] = tok_emb.shape[1]
    block_ids = []
    for key in state:
        m = re.match(r"blocks\.(\d+)\.", key)
        if m:
            block_ids.append(int(m.group(1)))
    if block_ids:
        merged["num_blocks"] = max(block_ids) + 1
    if "blocks.0.mlp.0.weight" in state or "blocks.0.mlp.2.weight" in state:
        merged["expansion_factor"] = infer_expansion_factor_from_state(state, merged["embed_dim"])
    if not merged.get("tokenizer_name"):
        merged["tokenizer_name"] = "gpt2"
    return merged
| def _is_state_dict(obj): | |
| if isinstance(obj, dict) and obj: | |
| sample_val = next(iter(obj.values())) | |
| return isinstance(sample_val, torch.Tensor) | |
| return False | |
def _load_state_from_pt(path: str):
    """Load a .pt/.bin checkpoint into (state_dict, config_dict).

    Accepts either a bare state_dict or a training payload of the form
    {"model": state, "config": {...}, "tokenizer_name": "..."}.
    NOTE(security): torch.load unpickles arbitrary objects — only load
    checkpoints from trusted sources.
    """
    payload = torch.load(path, map_location="cpu")
    if isinstance(payload, dict) and isinstance(payload.get("model"), dict):
        cfg = payload.get("config", {}) or {}
        if "tokenizer_name" in payload:
            cfg = {**cfg, "tokenizer_name": payload["tokenizer_name"]}
        return payload["model"], cfg
    if _is_state_dict(payload):
        return payload, {}
    raise ValueError(f"Unsupported .pt format at {path}: expected a state_dict or a payload with 'model'.")
| def _merge_state_dicts(dicts): | |
| merged = {} | |
| for d in dicts: | |
| for k, v in d.items(): | |
| merged[k] = v | |
| return merged | |
def _load_state_from_folder(weights_dir: str):
    """Load (state_dict, config) from a local folder.

    Priority: 'model.safetensors' > merged *.safetensors shards > merged
    .pt/.bin shards. A config.json (if present) seeds the config dict; .pt
    payloads may further extend it.
    """
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"Folder not found: {weights_dir}")
    cfg_path = os.path.join(weights_dir, "config.json")
    cfg = {}
    if os.path.exists(cfg_path):
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
    files = sorted(os.listdir(weights_dir))
    sft_files = [f for f in files if f.endswith(".safetensors")]
    pt_files = [f for f in files if f.endswith(".pt") or f.endswith(".bin")]
    state = None
    if "model.safetensors" in sft_files:
        state = load_sft(os.path.join(weights_dir, "model.safetensors"))
    elif sft_files:
        # Multiple shards: later files (sorted order) win on duplicate keys.
        parts = [load_sft(os.path.join(weights_dir, f)) for f in sft_files]
        state = _merge_state_dicts(parts)
    elif pt_files:
        parts = []
        for f in pt_files:
            # NOTE(security): torch.load unpickles arbitrary objects; only
            # load shards from trusted sources.
            part = torch.load(os.path.join(weights_dir, f), map_location="cpu")
            if isinstance(part, dict) and "model" in part and isinstance(part["model"], dict):
                parts.append(part["model"])
                if "config" in part and isinstance(part["config"], dict):
                    cfg = {**cfg, **part["config"]}
                if "tokenizer_name" in part:
                    cfg.setdefault("tokenizer_name", part["tokenizer_name"])
            elif _is_state_dict(part):
                parts.append(part)
            else:
                raise ValueError(f"Unsupported shard format: {f}")
        state = _merge_state_dicts(parts)
    else:
        raise FileNotFoundError(
            f"No weights found in {weights_dir}. Expected .safetensors or .pt files."
        )
    return state, cfg
def _load_state_from_hub(repo_id: str, subfolder: str | None = None, revision: str | None = None):
    """Snapshot a HF Hub repo locally, then load weights from it (or a subfolder)."""
    local_dir = snapshot_download(repo_id=repo_id, revision=revision, allow_patterns=None)
    target = os.path.join(local_dir, subfolder) if subfolder else local_dir
    return _load_state_from_folder(target)
def load_model(source: str):
    """Resolve `source` to weights, then build the CNA model and tokenizer.

    Resolution order: explicit .pt/.bin file -> local folder -> HF Hub repo id
    (anything containing '/') -> default fallbacks. Returns
    (model, tokenizer, radius) with the model in eval mode, forced to float32.
    """
    src = source or ""
    state, cfg = None, {}
    if os.path.isfile(src) and (src.endswith(".pt") or src.endswith(".bin")):
        state, cfg = _load_state_from_pt(src)
    elif os.path.isdir(src):
        state, cfg = _load_state_from_folder(src)
    elif "/" in src:  # Hub repo id
        subfolder = os.environ.get("WEIGHTS_SUBFOLDER") or None
        revision = os.environ.get("WEIGHTS_REVISION") or None
        state, cfg = _load_state_from_hub(src, subfolder=subfolder, revision=revision)
    else:
        # fallbacks
        if os.path.isfile("weights_latest.pt"):
            state, cfg = _load_state_from_pt("weights_latest.pt")
        elif os.path.isfile(DEFAULT_CKPT):
            state, cfg = _load_state_from_pt(DEFAULT_CKPT)
        elif os.path.isdir(DEFAULT_WEIGHTS_DIR):
            state, cfg = _load_state_from_folder(DEFAULT_WEIGHTS_DIR)
        else:
            raise FileNotFoundError(
                f"Could not resolve weights from '{src}'. Tried file (.pt), folder, hub repo id, "
                f"then defaults ('{DEFAULT_CKPT}', '{DEFAULT_WEIGHTS_DIR}')."
            )
    conf = _read_config_from_dict_or_infer(state, cfg)
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(conf["tokenizer_name"], use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Effectively disable the tokenizer's max-length truncation/warnings.
    tokenizer.model_max_length = 1_000_000_000
    vocab_size = tokenizer.vocab_size
    # Build model
    model = CNA(
        conf["embed_dim"], conf["num_heads"], conf["expansion_factor"],
        conf["num_blocks"], conf["radius"], vocab_size
    )
    # Load state (tolerate projection size mismatch)
    missing, unexpected = model.load_state_dict(state, strict=False)
    if any(k.startswith("proj.") for k in missing):
        # Checkpoint lacks a matching output head: re-initialize it fresh.
        with torch.no_grad():
            nn.init.normal_(model.proj.weight, std=0.02)
            nn.init.zeros_(model.proj.bias)
    else:
        # NOTE(review): the non-strict load above already copied the weights;
        # this strict pass re-loads only to raise if any key mismatched.
        model.load_state_dict(state, strict=True)
    # enforce float32 across params & buffers
    model = model.to(torch.float32)
    with torch.no_grad():
        for p in model.parameters():
            if p.dtype.is_floating_point:
                p.data = p.data.float()
        for _, buf in model.named_buffers():
            if buf.dtype.is_floating_point:
                buf.data = buf.data.float()
    model.eval()
    return model, tokenizer, conf["radius"]
# Process-wide cache of the last loaded model; "ckpt" remembers the source so
# ensure_model only reloads when the requested source changes.
model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
| def _auto_default_source(): | |
| env = os.environ.get("WEIGHTS_SOURCE") | |
| if env: | |
| return env | |
| if os.path.isdir("weights_latest"): | |
| return "weights_latest" | |
| for name in ["weights_latest.pt", "ckpt_latest.pt"]: | |
| if os.path.isfile(name): | |
| return name | |
| for f in sorted(os.listdir(".")): | |
| if f.endswith(".pt") or f.endswith(".safetensors"): | |
| return f | |
| return "weights_latest.pt" | |
def ensure_model(source_path_or_repo):
    """(Re)load the model into model_cache if absent or the source changed."""
    src = source_path_or_repo or _auto_default_source()
    if model_cache["model"] is not None and model_cache["ckpt"] == src:
        return
    model, tokenizer, radius = load_model(src)
    model_cache.update({"model": model, "tokenizer": tokenizer, "radius": radius, "ckpt": src})
# ===============================================
# Strategy 1 (random position) with argmax / sample
# ===============================================
def step_strategy1(model, x, mode="argmax", temperature=1.0, exclude_current=True):
    """One denoising step: pick a uniformly random position and rewrite its token.

    mode="sample" draws from the temperature-scaled distribution at that
    position (optionally excluding the token already there); any other mode
    takes the argmax. Mutates x in place and returns it.
    """
    seq_len = x.shape[1]
    pos = int(torch.randint(0, seq_len, (1,)).item())
    logits_at_pos = model_logits(model, x)[0, pos]
    if mode == "sample":
        replacement = sample_from_logits(
            logits_at_pos,
            temperature=float(temperature),
            current_token=int(x[0, pos].item()),
            exclude_current=bool(exclude_current),
        )
    else:
        replacement = int(torch.argmax(logits_at_pos).item())
    x[0, pos] = replacement
    return x
# ===============================================
# Gradio callbacks
# ===============================================
def init_random(src, seqlen, seed):
    """Gradio callback: seed the RNGs and create a uniformly random sequence.
    Returns (ids as nested list, preview text, status message)."""
    ensure_model(src)
    random.seed(seed)
    torch.manual_seed(seed)
    vocab = model_cache["tokenizer"].vocab_size
    length = int(seqlen)
    x = torch.randint(0, vocab, (1, length))
    preview = decode(x[0], model_cache["tokenizer"])
    return x.tolist(), preview, f"Initialized random sequence (len={length})"
def to_ranges(indices):
    """Compress a sorted list of token indices into 'a-b' CSV."""
    if not indices:
        return ""
    ordered = sorted(set(indices))
    runs = []
    run_start = run_end = ordered[0]
    for idx in ordered[1:]:
        if idx == run_end + 1:
            run_end = idx          # extend the current consecutive run
        else:
            runs.append((run_start, run_end))
            run_start = run_end = idx
    runs.append((run_start, run_end))
    return ", ".join(f"{a}" if a == b else f"{a}-{b}" for a, b in runs)
def capture_selection(text, seqlen, current_ids, evt: gr.SelectData | None = None):
    """
    Map highlighted character span in `text` to token index ranges using tokenizer offsets.
    Auto-fills the indices box so you can 'Noise Selection'.
    """
    ensure_model(None)
    tok = model_cache["tokenizer"]
    if not text:
        return gr.update(), "No text to select from."
    # Try to read (start, end) from the event payload
    start, end = None, None
    if evt is not None:
        try:
            # gradio SelectData for Textbox exposes .index = (start_char, end_char)
            start, end = evt.index
        except Exception:
            pass  # best-effort: fall through to the "no selection" branch
    # Fallback: nothing selected
    if start is None or end is None or start == end:
        return gr.update(), "No selection detected (drag to highlight)."
    # Bound the indices defensively
    start = max(0, min(len(text), int(start)))
    end = max(0, min(len(text), int(end)))
    # Get per-token char offsets from the fast tokenizer
    enc = tok(text, add_special_tokens=False, return_offsets_mapping=True)
    offsets = enc["offset_mapping"]  # list of (s,e) per token
    token_idxs = []
    for i, (s, e) in enumerate(offsets):
        if s is None or e is None:
            continue
        # overlap if token span intersects [start, end)
        if max(s, start) < min(e, end):
            token_idxs.append(i)
    if not token_idxs:
        return gr.update(), "Selection didn't hit any tokens (maybe whitespace)."
    # Clip to current sequence length (so we don't index beyond S)
    S = int(seqlen)
    token_idxs = [i for i in token_idxs if i < S]
    if not token_idxs:
        return gr.update(), "Selected span maps beyond current sequence length."
    indices_csv = to_ranges(token_idxs)
    return indices_csv, f"Selected chars [{start}:{end}) → tokens {indices_csv}"
def noise_selection(src, state_ids, seqlen, indices_csv, seed):
    """Gradio callback for 'Noise Selection': noise only the given indices."""
    # Reuse apply_noise but force prepend/append noise to zero
    return apply_noise(src, state_ids, seqlen, indices_csv, 0, 0, seed)
def apply_noise(src, state_ids, seqlen, indices_csv, add_left, add_right, seed):
    """Gradio callback: noise the stored sequence (or a fresh random one when
    no sequence exists yet). Returns (ids, preview text, status)."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    target_len = int(seqlen)
    if state_ids is None or len(state_ids) == 0:
        # No sequence yet — start from uniformly random tokens.
        base = torch.randint(0, tok.vocab_size, (1, target_len))
    else:
        base = to_batch2(state_ids)
    noised = apply_noise_ops(base, tok, indices_csv, int(add_left or 0), int(add_right or 0), target_len, seed=seed)
    preview = decode(noised[0], tok)
    return noised.tolist(), preview, "Applied noise"
def step_once(src, state_ids, mode, temperature, exclude_current):
    """Gradio callback: run exactly one denoising step on the stored sequence."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        return None, "", "No sequence to step — initialize first."
    x = step_strategy1(
        model_cache["model"], to_batch2(state_ids),
        mode=mode, temperature=temperature, exclude_current=exclude_current,
    )
    return x.tolist(), decode(x[0], tok), f"Stepped 1 iteration ({mode})"
def live_denoise(src, state_ids, steps, snap_every, seed, mode, temperature, exclude_current):
    """Gradio streaming callback: run `steps` denoise iterations, yielding
    (ids, preview text, status) every `snap_every` steps and at the end."""
    ensure_model(src)
    tok = model_cache["tokenizer"]
    if state_ids is None or len(state_ids) == 0:
        return  # nothing to denoise; generator ends without emitting updates
    random.seed(seed); torch.manual_seed(seed)
    x = to_batch2(state_ids)
    total = int(steps); snap = max(1, int(snap_every))
    for t in range(1, total + 1):
        x = step_strategy1(model_cache["model"], x, mode=mode, temperature=temperature, exclude_current=exclude_current)
        # Emit a snapshot on the cadence boundary and always on the final step.
        if (t % snap == 0) or (t == total):
            txt = decode(x[0], tok)
            yield x.tolist(), txt, f"Live denoise… step {t}/{total} ({mode})"
# ===============================================
# UI (single mode)
# ===============================================
with gr.Blocks(title="Self Organising Text Demo") as demo:
    gr.Markdown(
        """
# Self Organising Text Demo
Watch text self organise using only local attention.
"""
    )
    # Weights source: env override first, then auto-detected local files.
    default_source = os.environ.get("WEIGHTS_SOURCE", None)
    if default_source is None:
        default_source = _auto_default_source()
    with gr.Row():
        src = gr.Textbox(value=default_source, label="Weights (file / folder / HF repo id)")
        seqlen = gr.Slider(10, 512, value=50, step=1, label="Sequence length (S)")
        seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
    # Token ids of the current sequence, kept between callbacks as [[...]].
    ids_state = gr.State(value=None)
    with gr.Row():
        current_text = gr.Textbox(lines=8, label="Current text", interactive=True)
    status = gr.Markdown("Ready.")
    gr.Markdown("### Initialize & Denoise")
    with gr.Row():
        btn_random = gr.Button("Initialize Random")
        steps = gr.Slider(1, 2000, value=100, step=1, label="Denoise steps (N)")  # default 100
        snap_every = gr.Slider(1, 100, value=1, step=1, label="Update every K steps")  # default 1
    with gr.Row():
        update_mode = gr.Radio(
            choices=["argmax", "sample"],
            value="sample",  # default to sampling
            label="Update rule"
        )
        temperature = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.05, label="Temperature (sampling)")
        exclude_current = gr.Checkbox(value=True, label="Exclude current token when sampling")
    with gr.Row():
        btn_step_once = gr.Button("Step Once")
        btn_live = gr.Button("Denoise Live (streaming)")
    gr.Markdown("### Noise by Indices")
    with gr.Row():
        indices_csv = gr.Textbox(
            label="Positions to noise (enter like: 0, 5, 10-20)",
            placeholder="e.g., 0, 5, 10-20"
        )
    with gr.Row():
        add_left = gr.Number(value=0, precision=0, label="Noise tokens to add at START")
        add_right = gr.Number(value=0, precision=0, label="Noise tokens to add at END")
        btn_apply_noise = gr.Button("Apply Noise")
    # --- Wiring ---
    btn_random.click(init_random, [src, seqlen, seed], [ids_state, current_text, status])
    # Manual indices + prepend/append noise
    btn_apply_noise.click(
        apply_noise,
        [src, ids_state, seqlen, indices_csv, add_left, add_right, seed],
        [ids_state, current_text, status]
    )
    btn_step_once.click(
        step_once,
        [src, ids_state, update_mode, temperature, exclude_current],
        [ids_state, current_text, status]
    )
    # Streaming: live_denoise is a generator; queue() enables streamed yields.
    btn_live.click(
        live_denoise,
        [src, ids_state, steps, snap_every, seed, update_mode, temperature, exclude_current],
        [ids_state, current_text, status],
        show_progress=True
    )
demo.queue().launch()