the-puzzler commited on
Commit
d039848
·
1 Parent(s): 183d38f
Files changed (2) hide show
  1. README.md +8 -0
  2. app.py +263 -0
README.md CHANGED
@@ -11,3 +11,11 @@ short_description: Local Self Attention Allows text to coherently self evolve
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ # CNA — Simple Sampling (Strategy 1)
16
+
17
+ This Space runs the **random position → argmax** update over a fixed-length sequence using your CNA checkpoint.
18
+
19
+ ## Quick Start
20
+
21
+ 1. In your Space repo, add your trained checkpoint as:
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os, re, math, random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import gradio as gr
7
+ from transformers import AutoTokenizer
8
+
9
+ # -----------------------------
10
+ # Minimal CNA (inference-ready)
11
+ # -----------------------------
12
class AttnBlock(nn.Module):
    """Pre-norm transformer block with rotary position embeddings (RoPE)
    and attention restricted to a local band (|i - j| <= radius)."""

    def __init__(self, embed_dim, num_heads, expansion_factor):
        """
        Args:
            embed_dim: model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            expansion_factor: hidden-width multiplier for the MLP.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.norm1 = nn.LayerNorm(embed_dim)
        # Fused Q/K/V projection; split into thirds in forward().
        self.QKV = nn.Linear(embed_dim, embed_dim * 3)
        self.Wo = nn.Linear(embed_dim, embed_dim)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * expansion_factor),
            nn.GELU(),
            nn.Linear(embed_dim * expansion_factor, embed_dim),
        )

        # match training's zero-init on residual branches
        # (both residual branches output zero at construction, so a freshly
        # built block is the identity until real weights are loaded)
        nn.init.zeros_(self.Wo.weight); nn.init.zeros_(self.Wo.bias)
        nn.init.zeros_(self.mlp[-1].weight); nn.init.zeros_(self.mlp[-1].bias)

    def rope(self, Qh, Kh_seq, cos, sin):
        """Apply the rotary embedding to Q and K.

        Even/odd feature indices are treated as (real, imaginary) components
        of a pair and rotated by the per-position angle in cos/sin.
        cos/sin must broadcast against Qh/Kh_seq ([B, nh, S, hd]); the table
        builder appears to duplicate each angle across the (even, odd) pair
        so cos[..., 0::2] carries one angle per pair — TODO confirm against
        the rope table producer.
        """
        Qe = Qh[..., 0::2]; Qo = Qh[..., 1::2]
        ce = cos[..., 0::2]; se = sin[..., 0::2]
        # Standard 2-D rotation: (e, o) -> (e*cos - o*sin, e*sin + o*cos).
        Qr_e = Qe * ce - Qo * se
        Qr_o = Qe * se + Qo * ce
        # Re-interleave the rotated halves back into the original layout.
        Qh2 = torch.empty_like(Qh); Qh2[..., 0::2] = Qr_e; Qh2[..., 1::2] = Qr_o

        Ke = Kh_seq[..., 0::2]; Ko = Kh_seq[..., 1::2]
        Kr_e = Ke * ce - Ko * se
        Kr_o = Ke * se + Ko * ce
        Kh2 = torch.empty_like(Kh_seq); Kh2[..., 0::2] = Kr_e; Kh2[..., 1::2] = Kr_o
        return Qh2, Kh2

    def forward(self, x, rope, radius):
        """Run the block.

        Args:
            x: [B, S, E] activations.
            rope: (cos, sin) tables; permuted below from [1, S, 1, hd]
                  to [1, 1, S, hd] to broadcast over heads.
            radius: band half-width — positions attend only to neighbors
                    with |i - j| <= radius.
        Returns:
            [B, S, E] activations (same shape as x).
        """
        h = self.norm1(x)
        B, S, E = h.shape
        cos, sin = rope
        nh, hd = self.num_heads, self.head_dim

        cos = cos.to(h.dtype).to(h.device).permute(0,2,1,3) # [1,1,S,hd]
        sin = sin.to(h.dtype).to(h.device).permute(0,2,1,3)

        # local band mask
        # Additive mask: 0 inside the band, dtype-min (≈ -inf) outside,
        # so softmax assigns ~0 probability to out-of-band positions.
        idx = torch.arange(S, device=h.device)
        idx_dist = (idx.view(1, S) - idx.view(S, 1)).abs()
        neg_inf = torch.finfo(h.dtype).min
        mask = torch.full((S, S), neg_inf, dtype=h.dtype, device=h.device)
        mask[idx_dist <= int(radius)] = 0
        mask = mask.view(1, 1, S, S)

        qkv = self.QKV(h)
        q, k, v = qkv.chunk(3, dim=-1)

        # [B, S, E] -> [B, nh, S, hd] per head.
        Qh = q.view(B,S,nh,hd).permute(0,2,1,3).contiguous()
        Kh_seq = k.view(B,S,nh,hd).permute(0,2,1,3).contiguous()
        Vh = v.view(B,S,nh,hd).permute(0,2,1,3).contiguous()

        assert hd % 2 == 0, "rope needs even head_dim"
        Qh, Kh_seq = self.rope(Qh, Kh_seq, cos, sin)
        Kh = Kh_seq.permute(0,1,3,2).contiguous()  # transpose for Q @ K^T

        # Scaled dot-product attention with the band mask added pre-softmax.
        logits = (Qh @ Kh) * (hd ** -0.5)
        attn = F.softmax(logits + mask, dim=-1) @ Vh
        attn = attn.permute(0,2,1,3).contiguous().view(B,S,E)

        # Residual connections around attention and MLP.
        x = x + self.Wo(attn)
        x = x + self.mlp(self.norm2(x))
        return x
83
+
84
class CNA(nn.Module):
    """Stack of local-attention blocks over token embeddings, projected
    back to vocabulary logits. Accepts token ids or raw embeddings."""

    def __init__(self, embed_dim, num_heads, expansion_factor, num_blocks, radius, vocab_size):
        """
        Args:
            embed_dim: model width.
            num_heads: attention heads per block.
            expansion_factor: MLP hidden-width multiplier.
            num_blocks: number of stacked AttnBlocks.
            radius: local-attention band half-width (passed to every block).
            vocab_size: embedding table / output projection size.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.expansion_factor = expansion_factor
        self.num_blocks = num_blocks
        self.vocab_size = vocab_size
        self.radius = radius
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([AttnBlock(embed_dim, num_heads, expansion_factor) for _ in range(num_blocks)])
        self.proj = nn.Linear(embed_dim, vocab_size)

    def _rope_seq(self, S, hd, device, dtype, base=10000.0):
        """Build RoPE cos/sin tables of shape [1, S, 1, hd].

        Each of the hd // 2 rotation angles is duplicated across adjacent
        feature indices (via the stack + reshape below), matching the
        even/odd pairing used in AttnBlock.rope.
        """
        pos = torch.arange(S, device=device, dtype=dtype)
        half = hd // 2
        idx = torch.arange(half, device=device, dtype=dtype)
        # Standard RoPE frequencies: base ** (-2i / hd).
        inv = base ** (-idx / half)
        ang = pos[:, None] * inv[None, :]  # [S, half] angle per (pos, pair)
        cos = ang.cos().unsqueeze(0).unsqueeze(2)
        sin = ang.sin().unsqueeze(0).unsqueeze(2)
        # Duplicate each angle for the (even, odd) pair: [c0,c0,c1,c1,...].
        cos = torch.stack((cos, cos), dim=-1).reshape(1, S, 1, hd)
        sin = torch.stack((sin, sin), dim=-1).reshape(1, S, 1, hd)
        return cos, sin

    def forward(self, x):
        """Return logits [B, S, vocab_size].

        Args:
            x: either int64 token ids [B, S] (embedded via tok_emb) or
               pre-computed float embeddings [B, S, E] (used as-is).
        """
        if x.dtype == torch.long and x.dim() == 2:
            h = self.tok_emb(x)
        else:
            h = x
        B, S, E = h.shape
        hd = self.embed_dim // self.num_heads
        # RoPE tables are rebuilt per call for the current sequence length.
        cos, sin = self._rope_seq(S, hd, h.device, h.dtype)
        for blk in self.blocks:
            h = blk(h, rope=(cos, sin), radius=self.radius)
        return self.proj(h)
120
+
121
+ # -----------------------------
122
+ # Helpers (trimmed to Strategy 1)
123
+ # -----------------------------
124
def infer_expansion_factor_from_state(state, embed_dim):
    """Recover the MLP expansion factor from a checkpoint state dict.

    Inspects the first block's MLP weights: the up-projection's output
    width (or, failing that, the down-projection's input width) divided
    by embed_dim. Falls back to 4 when neither weight is present.
    """
    up_proj = state.get("blocks.0.mlp.0.weight")
    if up_proj is not None:
        return int(up_proj.shape[0] // embed_dim)
    down_proj = state.get("blocks.0.mlp.2.weight")
    if down_proj is not None:
        return int(down_proj.shape[1] // embed_dim)
    return 4
133
+
134
@torch.no_grad()
def decode(ids, tokenizer, max_chars=220):
    """Decode a 1-D tensor of token ids into a one-line preview string.

    Newlines are flattened to spaces; output longer than max_chars is
    truncated with a trailing ellipsis.
    """
    text = tokenizer.decode(ids.tolist(), skip_special_tokens=True)
    text = text.replace("\n", " ")
    if len(text) > max_chars:
        return text[:max_chars] + "…"
    return text
139
+
140
@torch.no_grad()
def model_logits(model, x):
    """Forward pass with autograd disabled; returns the model's raw logits."""
    logits = model(x)
    return logits
143
+
144
+ # -----------------------------
145
+ # Load checkpoint & build model
146
+ # -----------------------------
147
def load_model(ckpt_path: str):
    """Load a CNA checkpoint and its tokenizer.

    Config values missing from the checkpoint payload are inferred from the
    state dict itself (with conservative fallbacks). A vocabulary/projection
    head whose keys are absent from the checkpoint is tolerated by keeping a
    freshly initialized head.

    Args:
        ckpt_path: path to a torch-saved payload with at least a "model"
            state dict; optionally "config" and "tokenizer_name".
    Returns:
        (model, tokenizer, radius) — model is in eval mode.
    Raises:
        FileNotFoundError: when ckpt_path does not exist.
        RuntimeError: when non-head keys are missing or unexpected.
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(
            f"Checkpoint not found at {ckpt_path}. "
            "Upload ckpt_latest.pt to the repo root or set the correct path."
        )
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source (here: this Space's own repo file).
    payload = torch.load(ckpt_path, map_location="cpu")
    state = payload["model"]
    cfg = payload.get("config", {}) or {}

    # Carry over config (robust fallbacks when the payload has no config).
    embed_dim = cfg.get("embed_dim")
    num_heads = cfg.get("num_heads")
    num_blocks = cfg.get("num_blocks")
    radius = cfg.get("radius")
    expansion_factor = cfg.get("expansion_factor")

    if embed_dim is None:
        embed_dim = state["tok_emb.weight"].shape[1]
    if num_blocks is None:
        # Count distinct "blocks.<i>." prefixes in the state dict.
        block_idxs = [int(m.group(1)) for k in state.keys()
                      for m in [re.match(r"blocks\.(\d+)\.", k)] if m]
        num_blocks = max(block_idxs) + 1 if block_idxs else 1
    if num_heads is None:
        num_heads = 8
    if radius is None:
        radius = 16
    if expansion_factor is None:
        expansion_factor = infer_expansion_factor_from_state(state, embed_dim)
    else:
        expansion_factor = int(expansion_factor)

    tokenizer_name = payload.get("tokenizer_name", "gpt2")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Effectively disable the tokenizer's max-length warnings.
    tokenizer.model_max_length = 1_000_000_000
    vocab_size = tokenizer.vocab_size

    model = CNA(
        int(embed_dim), int(num_heads), int(expansion_factor),
        int(num_blocks), int(radius), int(vocab_size)
    )

    # Load weights once; tolerate an absent projection head (e.g. the
    # checkpoint was trained against a different vocab size).
    missing, unexpected = model.load_state_dict(state, strict=False)
    if any(k.startswith("proj.") for k in missing):
        # Keep a freshly initialized head in place of the missing one.
        with torch.no_grad():
            nn.init.normal_(model.proj.weight, std=0.02)
            nn.init.zeros_(model.proj.bias)
    elif missing or unexpected:
        # Previously this re-ran load_state_dict(strict=True) just to raise;
        # raise directly from the lists we already have instead of loading
        # the entire state dict a second time.
        raise RuntimeError(
            f"Error(s) in loading state_dict: "
            f"missing keys {missing}, unexpected keys {unexpected}"
        )

    model.eval()
    return model, tokenizer, int(radius)
198
+
199
+ # -----------------------------
200
+ # Simplest sampling: Strategy 1
201
+ # -----------------------------
202
@torch.no_grad()
def strategy_random_argmax(model, tokenizer, seqlen=100, steps=200, snap_every=20, seed=0, max_chars=220):
    """Strategy 1: repeatedly pick a uniformly random position and replace
    its token with the model's argmax prediction there.

    Starts from a uniformly random token sequence of length seqlen and
    returns a list of (step, decoded_text) snapshots: the initial state
    (step 0), every snap_every steps, and the final step.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    vocab = tokenizer.vocab_size
    seq = torch.randint(0, vocab, (1, seqlen))
    snapshots = [(0, decode(seq[0].cpu(), tokenizer, max_chars))]
    for step in range(1, steps + 1):
        pos = int(torch.randint(0, seqlen, (1,)))
        # Full forward pass; only the chosen position's logits are consumed.
        new_token = int(torch.argmax(model_logits(model, seq)[0, pos]).item())
        seq[0, pos] = new_token
        if step % snap_every == 0 or step == steps:
            snapshots.append((step, decode(seq[0].cpu(), tokenizer, max_chars)))
    return snapshots
215
+
216
+ # -----------------------------
217
+ # Gradio UI
218
+ # -----------------------------
219
DEFAULT_CKPT = os.environ.get("CKPT_PATH", "ckpt_latest.pt")

# Process-wide cache so a checkpoint is only loaded once per path.
model_cache = {"model": None, "tokenizer": None, "radius": None, "ckpt": None}
def ensure_model(ckpt_path):
    """Populate model_cache from ckpt_path unless that path is already loaded."""
    already_loaded = model_cache["model"] is not None and model_cache["ckpt"] == ckpt_path
    if not already_loaded:
        mdl, tok, rad = load_model(ckpt_path)
        model_cache.update({"model": mdl, "tokenizer": tok, "radius": rad, "ckpt": ckpt_path})
226
+
227
def run_demo(ckpt_path, seqlen, steps, snap_every, seed, max_chars):
    """Gradio callback: load (or reuse) the model and run Strategy 1.

    Args mirror the UI controls; numeric values are cast to int because
    Gradio sliders can deliver floats even with step=1, and torch.randint /
    range() require real ints.
    Returns:
        (log, final_text): the snapshot log (one line per snapshot) and the
        text of the last snapshot.
    """
    ensure_model(ckpt_path or DEFAULT_CKPT)
    snaps = strategy_random_argmax(
        model_cache["model"], model_cache["tokenizer"],
        seqlen=int(seqlen), steps=int(steps), snap_every=int(snap_every),
        seed=int(seed), max_chars=int(max_chars)
    )
    # Pretty print log
    log = "\n".join([f"t={t:>3}: {txt}" for (t, txt) in snaps])
    final_text = snaps[-1][1] if snaps else ""
    return log, final_text
238
+
239
# Gradio UI: controls for the checkpoint path and sampling hyperparameters,
# wired to run_demo.
with gr.Blocks(title="CNA — Simple Sampling (Random Position • Argmax)") as demo:
    gr.Markdown(
        """
# CNA — Simple Sampling (Strategy 1)
This Space loads your checkpoint and runs the **random position → argmax** update for a fixed-length sequence.
- Put your checkpoint at `ckpt_latest.pt` (repo root), or set a custom path below.
"""
    )
    with gr.Row():
        ckpt = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint path", placeholder="ckpt_latest.pt")
    with gr.Row():
        seqlen = gr.Slider(10, 512, value=100, step=1, label="Sequence length (S)")
        steps = gr.Slider(10, 1000, value=200, step=1, label="Steps")
        snap_every = gr.Slider(1, 200, value=20, step=1, label="Snapshot every N steps")
    with gr.Row():
        seed = gr.Slider(0, 10_000, value=0, step=1, label="Seed")
        max_chars = gr.Slider(32, 1000, value=220, step=1, label="Max chars per snapshot")
    run_btn = gr.Button("Run")
    with gr.Row():
        log_out = gr.Textbox(lines=18, label="Snapshots")
        final_out = gr.Textbox(lines=6, label="Final text (last snapshot)")

    run_btn.click(run_demo, [ckpt, seqlen, steps, snap_every, seed, max_chars], [log_out, final_out])

# Gradio 4.x removed queue(concurrency_count=...) in favor of
# default_concurrency_limit; calling the old keyword raises TypeError at
# startup. Try the legacy spelling first so the Space runs on either major
# version.
try:
    demo.queue(concurrency_count=1)
except TypeError:
    demo.queue(default_concurrency_limit=1)
demo.launch()