AbstractPhil committed on
Commit
3c6b358
·
verified ·
1 Parent(s): 5d154e8

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +665 -0
inference.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ K-Simplex Language Model - Inference Script
4
+
5
+ Loads a trained k-simplex LLM checkpoint and generates text using
6
+ geometrically-validated autoregressive sampling.
7
+
8
+ Usage:
9
+ python inference.py --checkpoint checkpoint_epoch_008.pt --prompt "ROMEO: "
10
+ python inference.py --repo AbstractPhil/ksimplex-llm-prototype --prompt "To be or not"
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import math
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import tiktoken
20
+ from pathlib import Path
21
+ from huggingface_hub import hf_hub_download
22
+
23
+
24
+ # =============================================================================
25
+ # GEOMETRIC CORE
26
+ # =============================================================================
27
+
28
def factorial(n: int) -> int:
    """Return ``n!`` for a non-negative integer ``n``.

    Thin wrapper around :func:`math.factorial`, so invalid input (for example
    a negative integer) raises whatever the standard library raises.
    """
    result = math.factorial(n)
    return result
30
+
31
+
32
def cayley_menger_volume_squared(vertices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute pairwise squared distances and the squared simplex volume.

    The squared volume comes from the Cayley-Menger determinant:

        Vol^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2)

    where ``k = nv - 1`` is the simplex dimension and ``CM`` is the squared
    distance matrix bordered by a row/column of ones (with a zero corner).

    Args:
        vertices: [*, nv, edim] vertex coordinates

    Returns:
        d2: [*, n_pairs] squared distances for the pairs (i, j), i < j,
            in row-major upper-triangle order
        vol2: [*] squared volume (can be <= 0 for degenerate simplices or
            through floating point error)
    """
    nv = vertices.shape[-2]
    k = nv - 1  # simplex dimension

    # Pairwise squared distances: [*, nv, nv]
    diff = vertices.unsqueeze(-2) - vertices.unsqueeze(-3)
    d2_matrix = diff.pow(2).sum(-1)

    # Upper-triangle pair indices, created on the input's device so indexing
    # an accelerator tensor does not go through CPU index tensors.
    rows, cols = torch.triu_indices(nv, nv, offset=1, device=vertices.device)
    d2 = d2_matrix[..., rows, cols]  # [*, n_pairs]

    # Bordered (Cayley-Menger) matrix: [*, nv + 1, nv + 1]
    size = nv + 1
    cm = vertices.new_zeros((*vertices.shape[:-2], size, size))
    cm[..., 0, 1:] = 1.0  # first row: [0, 1, ..., 1]
    cm[..., 1:, 0] = 1.0  # first column: [0, 1, ..., 1]^T
    cm[..., 1:, 1:] = d2_matrix  # diagonal of d2_matrix is already 0

    det = torch.linalg.det(cm)

    sign = (-1) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    vol2 = sign * det / denom

    return d2, vol2
77
+
78
+
79
+ # =============================================================================
80
+ # MODEL COMPONENTS
81
+ # =============================================================================
82
+
83
class SimplexTemplate(nn.Module):
    """Fixed (non-trainable) template vertices for a k-simplex.

    The ``k + 1`` vertices sit at equally spaced angles on a circle of radius
    ``scale`` in the first two coordinates; every further coordinate carries a
    small sinusoidal offset of the same angle. The layout is deterministic and
    is stored in the buffer ``template`` with shape ``[k + 1, edim]``.

    Note: the construction is circle-based, so for k > 2 the edges are not all
    of equal length (it is not a strictly regular simplex).
    """

    def __init__(self, k: int, edim: int, scale: float = 1.0):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim

        layout = torch.zeros(self.nv, edim)
        for row in range(self.nv):
            angle = 2 * math.pi * row / self.nv
            # Axis 0 is always written; the remaining axes only if they exist.
            layout[row, 0] = self._coordinate(angle, 0, scale)
            for axis in range(1, edim):
                layout[row, axis] = self._coordinate(angle, axis, scale)

        self.register_buffer('template', layout)

    @staticmethod
    def _coordinate(angle: float, axis: int, scale: float) -> float:
        """Return the template coordinate on ``axis`` for a vertex at ``angle``."""
        if axis == 0:
            return scale * math.cos(angle)
        if axis == 1:
            return scale * math.sin(angle)
        if axis == 2:
            return scale * 0.3 * math.cos(angle * 2)
        return scale * 0.1 * math.sin(angle * (axis + 1))

    def forward(self) -> torch.Tensor:
        """Return the stored ``[k + 1, edim]`` template tensor."""
        return self.template
108
+
109
+
110
class KSimplexChannel(nn.Module):
    """One k-simplex channel: deformed template geometry gating vertex features.

    A hidden vector is projected to (a) per-vertex coordinate offsets that
    deform the fixed template and (b) per-vertex features. The deformed
    simplex is summarised by its pairwise squared distances and squared
    volume; that geometry vector drives a sigmoid gate over the pooled vertex
    features and is also appended to the output.
    """

    def __init__(self, k: int, edim: int, hidden: int, feat_dim: int, base_deform: float = 0.05):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim
        self.feat_dim = feat_dim
        self.base_deform = base_deform

        # Fixed template vertices, deformed per input in forward().
        self.template = SimplexTemplate(k, edim)

        # Hidden state -> per-vertex coordinate offsets / per-vertex features.
        self._to_coords = nn.Linear(hidden, self.nv * edim)
        self._to_feats = nn.Linear(hidden, self.nv * feat_dim)

        # Geometry vector = one squared distance per vertex pair + vol².
        edge_count = (self.nv * (self.nv - 1)) // 2
        self.geo_dim = edge_count + 1

        # Learned gate computed from the geometry vector.
        self._geo_gate = nn.Sequential(
            nn.Linear(self.geo_dim, feat_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [*, hidden]

        Returns:
            out: [*, feat_dim + geo_dim] gated features followed by geometry
            vol2: [*] squared volume of the deformed simplex
            mean_d2: [*] mean of the pairwise squared distances
        """
        # Deform the template by a small learned offset.
        offsets = self._to_coords(x).unflatten(-1, (self.nv, self.edim))
        verts = self.template() + self.base_deform * offsets

        # Geometry of the deformed simplex.
        d2, vol2 = cayley_menger_volume_squared(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)

        # Pool per-vertex features over the vertex axis.
        vert_feats = self._to_feats(x).unflatten(-1, (self.nv, self.feat_dim))
        pooled = vert_feats.mean(dim=-2)

        # Learned gate plus a steep sigmoid on vol² that is ~1 for positive
        # volume and ~0 for negative volume.
        gate = self._geo_gate(geo)
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)
        feat_agg = pooled * gate * validity

        out = torch.cat([feat_agg, geo], dim=-1)
        mean_d2 = d2.mean(dim=-1)
        return out, vol2, mean_d2
172
+
173
+
174
class TokenToKChannels(nn.Module):
    """Project token embeddings into a stack of k-simplex channels.

    One :class:`KSimplexChannel` is built for each simplex dimension
    ``k = 1 .. depth``. Their outputs have different widths (the geometry
    vector grows with k), so every channel narrower than the widest one gets
    a linear projection up to ``max_dim`` before the outputs are stacked.
    """

    def __init__(self, embed_dim: int, hidden: int, depth: int, edim: int, feat_dim: int):
        super().__init__()
        self.depth = depth

        self._proj = nn.Linear(embed_dim, hidden)

        channels = []
        for level in range(1, depth + 1):
            channels.append(
                KSimplexChannel(k=level, edim=edim, hidden=hidden, feat_dim=feat_dim)
            )
        self._channels = nn.ModuleList(channels)

        # Per-channel output width and the common (maximum) width.
        self.out_dims = [ch.feat_dim + ch.geo_dim for ch in self._channels]
        self.max_dim = max(self.out_dims)

        # Widening projections; the widest channel passes through unchanged.
        pads = []
        for width in self.out_dims:
            if width == self.max_dim:
                pads.append(nn.Identity())
            else:
                pads.append(nn.Linear(width, self.max_dim))
        self._pads = nn.ModuleList(pads)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """
        Args:
            x: [B, T, embed_dim]

        Returns:
            out: [B, T, K, max_dim] with K = number of channels
            vol2_list: list of [B, T], one entry per channel
            d2_list: list of [B, T] mean squared distances, one per channel
        """
        h = self._proj(x)  # [B, T, hidden]

        padded, vol2_list, d2_list = [], [], []
        for channel, pad in zip(self._channels, self._pads):
            feats, vol2, mean_d2 = channel(h)
            padded.append(pad(feats))
            vol2_list.append(vol2)
            d2_list.append(mean_d2)

        # New K axis is inserted just before the feature axis.
        stacked = torch.stack(padded, dim=-2)
        return stacked, vol2_list, d2_list
223
+
224
+
225
class KChannelCrossAttention(nn.Module):
    """Self-attention across the K k-levels, independently at each position.

    Every (batch, position) pair is treated as its own length-K sequence, so
    information is mixed between k-levels but never between positions.
    """

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        B, T, K, D = x.shape

        # Fold batch and position together: [B*T, K, D].
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_flat = x.reshape(B * T, K, D)

        # Self-attention across k-levels
        attn_out, _ = self.attn(x_flat, x_flat, x_flat)

        # Residual + norm (post-norm)
        out = self.norm(x_flat + attn_out)

        return out.reshape(B, T, K, D)
252
+
253
+
254
class CausalSequenceAttention(nn.Module):
    """Causal self-attention across sequence positions.

    The K k-levels of each position are flattened into one feature vector of
    width ``K * D`` before attention, so ``dim`` must equal ``K * D`` of the
    tensors passed to :meth:`forward`.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

        # Lower-triangular "may attend" mask, sliced per call in forward().
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        self.register_buffer('_causal_mask', mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D] with ``K * D == dim`` and ``T <= max_seq_len``
        Returns:
            [B, T, K, D]
        Raises:
            ValueError: if ``T`` exceeds the mask size or ``K * D`` does not
                match the width this module was built for.
        """
        B, T, K, D = x.shape

        max_len = self._causal_mask.shape[0]
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {max_len}"
            )
        if K * D != self.attn.embed_dim:
            raise ValueError(
                f"expected K * D == {self.attn.embed_dim}, got {K} * {D} = {K * D}"
            )

        # Flatten K into D: [B, T, K*D]. reshape (not view) so non-contiguous
        # inputs are accepted as well.
        x_flat = x.reshape(B, T, K * D)

        # Additive mask: 0 where attention is allowed, -inf for future
        # positions. Built in the input dtype/device as a contiguous tensor.
        blocked = ~self._causal_mask[:T, :T]  # True = masked
        attn_mask = torch.zeros(
            T, T, dtype=x_flat.dtype, device=x_flat.device
        ).masked_fill(blocked, float('-inf'))

        # Self-attention across sequence
        attn_out, _ = self.attn(x_flat, x_flat, x_flat, attn_mask=attn_mask)

        # Residual + norm (post-norm)
        out = self.norm(x_flat + attn_out)

        return out.reshape(B, T, K, D)
292
+
293
+
294
class GeoBlock(nn.Module):
    """Geometric block: k-channel attention + causal sequence attention + MLP.

    ``dim`` is the per-k-level feature width D and ``depth`` is the number of
    k-levels K. The k-channel attention works on width D; the sequence
    attention and the MLP both work on the flattened width ``K * D``.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, depth: int, dropout: float = 0.1):
        super().__init__()
        flat_dim = dim * depth

        self.k_attn = KChannelCrossAttention(dim, num_heads=4, dropout=dropout)
        # The sequence attention flattens [B, T, K, D] to [B, T, K*D] before
        # attending, so it has to be built for the flattened width. Building
        # it with `dim` makes forward() fail for any depth > 1.
        # NOTE(review): confirm this matches the layer shapes of the trained
        # checkpoints; load_state_dict will report a mismatch otherwise.
        self.seq_attn = CausalSequenceAttention(flat_dim, num_heads, max_seq_len, dropout)

        self.mlp = nn.Sequential(
            nn.Linear(flat_dim, flat_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(flat_dim * 4, flat_dim),
            nn.Dropout(dropout),
        )
        self.mlp_norm = nn.LayerNorm(flat_dim)
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D] with K == depth and D == dim
        Returns:
            [B, T, K, D]
        """
        # Mix k-levels at each position, then mix positions causally.
        x = self.k_attn(x)
        x = self.seq_attn(x)

        # Position-wise MLP on the flattened k-channels (residual + post-norm).
        B, T, K, D = x.shape
        x_flat = x.reshape(B, T, K * D)
        x_flat = self.mlp_norm(x_flat + self.mlp(x_flat))

        return x_flat.reshape(B, T, K, D)
331
+
332
+
333
class KSimplexLM(nn.Module):
    """K-Simplex Language Model.

    Pipeline: token + learned position embeddings -> per-token k-simplex
    channels -> a stack of :class:`GeoBlock` layers -> LayerNorm -> linear
    LM head over the flattened k-channels.
    """

    def __init__(
        self,
        vocab_size: int = 50257,
        max_seq_len: int = 256,
        embed_dim: int = 384,
        depth: int = 4,
        edim: int = 16,
        feat_dim: int = 96,
        hidden: int = 384,
        num_heads: int = 8,
        num_blocks: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.depth = depth

        # Token and position embeddings
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.embed_drop = nn.Dropout(dropout)

        # Embedding -> [B, T, K, k_dim] k-simplex channels
        self.to_k_channels = TokenToKChannels(embed_dim, hidden, depth, edim, feat_dim)
        k_dim = self.to_k_channels.max_dim

        # Stack of geometric blocks
        blocks = []
        for _ in range(num_blocks):
            blocks.append(GeoBlock(k_dim, num_heads, max_seq_len, depth, dropout))
        self.blocks = nn.ModuleList(blocks)

        # Final norm and LM head on the flattened k-channels. The head is not
        # weight-tied to the token embedding.
        flat_dim = k_dim * depth
        self.ln_f = nn.LayerNorm(flat_dim)
        self.lm_head = nn.Linear(flat_dim, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear/Embedding weights ~ N(0, 0.02²), Linear biases to 0."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                nn.init.normal_(module.weight, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Args:
            x: [B, T] token indices

        Returns:
            logits: [B, T, vocab_size]
            geo_info: dict with keys 'vol2' and 'd2', each a list holding one
                tensor per k-level
        """
        batch, seq_len = x.shape

        # Token + position embeddings
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        h = self.embed_drop(self.embed(x) + self.pos_embed(positions))

        # Per-token k-simplex channels, plus geometry side outputs
        h, vol2_list, d2_list = self.to_k_channels(h)

        for block in self.blocks:
            h = block(h)

        # Flatten k-channels and project to the vocabulary
        logits = self.lm_head(self.ln_f(h.view(batch, seq_len, -1)))

        return logits, {'vol2': vol2_list, 'd2': d2_list}
421
+
422
+
423
+ # =============================================================================
424
+ # INFERENCE UTILITIES
425
+ # =============================================================================
426
+
427
def load_model(
    checkpoint_path: str = None,
    repo_id: str = None,
    device: str = None,
) -> tuple[KSimplexLM, tiktoken.Encoding]:
    """
    Load model from checkpoint or HuggingFace Hub.

    When ``repo_id`` is given it takes precedence: ``checkpoint_latest.pt``
    and ``config.json`` are downloaded from the Hub. Otherwise the model
    configuration is read from the ``'config'`` entry of the local checkpoint.
    In both cases the hyper-parameters may be stored either flat or nested
    under a ``'model'`` key; missing values fall back to the KSimplexLM
    defaults used below.

    Args:
        checkpoint_path: Local path to checkpoint
        repo_id: HuggingFace repo ID (e.g., "AbstractPhil/ksimplex-llm-prototype")
        device: Device to load to (defaults to CUDA when available, else CPU)

    Returns:
        model: KSimplexLM in eval mode on ``device``
        tokenizer: tiktoken "gpt2" encoding

    Raises:
        ValueError: if neither ``checkpoint_path`` nor ``repo_id`` is given.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Resolve where the checkpoint (and, for the Hub, the config) comes from.
    repo_config = None
    if repo_id:
        checkpoint_path = hf_hub_download(repo_id, "checkpoint_latest.pt")
        config_path = hf_hub_download(repo_id, "config.json")
        with open(config_path, encoding="utf-8") as f:
            repo_config = json.load(f)
    elif not checkpoint_path:
        raise ValueError("Must provide checkpoint_path or repo_id")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if repo_config is not None:
        config = repo_config
    else:
        config = checkpoint.get('config') or {}
    # Accept both {"model": {...}} and a flat {...} layout.
    nested = config.get('model')
    if isinstance(nested, dict):
        config = nested

    # Build model
    model = KSimplexLM(
        vocab_size=config.get('vocab_size', 50257),
        max_seq_len=config.get('max_seq_len', 256),
        embed_dim=config.get('embed_dim', 384),
        depth=config.get('depth', 4),
        edim=config.get('edim', 16),
        feat_dim=config.get('feat_dim', 96),
        hidden=config.get('hidden', 384),
        num_heads=config.get('num_heads', 8),
        num_blocks=config.get('num_blocks', 8),
        dropout=0.0,  # No dropout at inference
    )

    # The checkpoint is either a dict with 'model_state_dict' or a bare
    # state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    # Tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    return model, tokenizer
486
+
487
+
488
@torch.no_grad()
def generate(
    model: "KSimplexLM",
    tokenizer: "tiktoken.Encoding",
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    device: str = None,
) -> str:
    """
    Generate text from prompt.

    The model only ever sees the most recent ``model.max_seq_len`` tokens,
    but the full token history (prompt + everything generated) is kept and
    decoded for the return value.

    Args:
        model: KSimplexLM model
        tokenizer: tiktoken encoding
        prompt: Input text prompt (must encode to at least one token)
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature; 0 selects greedy (argmax) decoding
        top_k: Keep only the k most likely tokens (<= 0 or None disables)
        top_p: Nucleus sampling threshold (>= 1.0 disables)
        device: Device (defaults to the device of the model parameters)

    Returns:
        Generated text including prompt

    Raises:
        ValueError: if ``temperature`` is negative or the prompt encodes to
            no tokens.
    """
    if temperature < 0:
        raise ValueError("temperature must be >= 0")
    if device is None:
        device = next(model.parameters()).device

    # Encode prompt
    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    for _ in range(max_tokens):
        # Sliding context window. `tokens` itself is never truncated, so the
        # returned text keeps the prompt and all earlier output.
        context = tokens[:, -model.max_seq_len:]

        logits, _ = model(context)
        logits = logits[:, -1, :]  # Last position

        if temperature == 0:
            # Greedy decoding; avoids dividing by zero.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            # Top-k: drop everything below the k-th largest logit.
            if top_k is not None and top_k > 0:
                kth, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = logits.masked_fill(logits < kth[:, [-1]], float('-inf'))

            # Top-p (nucleus)
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Drop tokens after the cumulative mass passes top_p. Shift by
                # one so the token that crosses the threshold is kept, and
                # always keep the most likely token.
                drop_sorted = cumulative_probs > top_p
                drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
                drop_sorted[..., 0] = False

                # Map the mask from sorted order back to vocabulary order.
                drop = drop_sorted.scatter(1, sorted_indices, drop_sorted)
                logits = logits.masked_fill(drop, float('-inf'))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        tokens = torch.cat([tokens, next_token], dim=1)

        # Stop on EOS
        if next_token.item() == tokenizer.eot_token:
            break

    # Decode
    return tokenizer.decode(tokens[0].tolist())
566
+
567
+
568
@torch.no_grad()
def analyze_geometry(
    model: "KSimplexLM",
    tokenizer: "tiktoken.Encoding",
    text: str,
    device: str = None,
) -> dict:
    """
    Analyze geometric properties of text encoding.

    Runs a single forward pass and summarises the per-k-level geometry the
    model reports. Text longer than ``model.max_seq_len`` tokens is cut to
    its most recent ``max_seq_len`` tokens, the same window ``generate``
    feeds to the model.

    Args:
        model: KSimplexLM model
        tokenizer: tiktoken encoding
        text: Input text (must encode to at least one token)
        device: Device (defaults to the device of the model parameters)

    Returns:
        Dictionary keyed 'k1', 'k2', ... (one entry per k-level), each a dict
        with 'vol2_mean', 'vol2_std', 'vol2_min', 'vol2_max', 'validity_rate'
        (fraction of positions with vol² > 0) and 'd2_mean'.

    Raises:
        ValueError: if the text encodes to no tokens.
    """
    if device is None:
        device = next(model.parameters()).device

    token_ids = tokenizer.encode(text)
    if not token_ids:
        raise ValueError("text must encode to at least one token")
    # Stay within the positional-embedding range of the model.
    token_ids = token_ids[-model.max_seq_len:]
    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)

    _, geo_info = model(tokens)

    stats = {}
    for k, (vol2, d2) in enumerate(zip(geo_info['vol2'], geo_info['d2']), 1):
        vol2_np = vol2.cpu().numpy()
        d2_np = d2.cpu().numpy()

        stats[f'k{k}'] = {
            'vol2_mean': float(vol2_np.mean()),
            'vol2_std': float(vol2_np.std()),
            'vol2_min': float(vol2_np.min()),
            'vol2_max': float(vol2_np.max()),
            'validity_rate': float((vol2_np > 0).mean()),
            'd2_mean': float(d2_np.mean()),
        }

    return stats
610
+
611
+
612
+ # =============================================================================
613
+ # CLI
614
+ # =============================================================================
615
+
616
def main():
    """Command-line entry point.

    Parses the CLI options, loads the model (from ``--checkpoint`` when given,
    otherwise from the ``--repo`` Hub repository) and then either prints
    per-k-level geometry statistics for the prompt (``--analyze``) or prints
    text generated from it.
    """
    cli = argparse.ArgumentParser(description='K-Simplex LLM Inference')
    cli.add_argument('--checkpoint', type=str, help='Path to checkpoint file')
    cli.add_argument(
        '--repo', type=str, default='AbstractPhil/ksimplex-llm-prototype',
        help='HuggingFace repo ID',
    )
    cli.add_argument('--prompt', type=str, default='ROMEO: ', help='Text prompt')
    cli.add_argument(
        '--max_tokens', type=int, default=100,
        help='Maximum tokens to generate',
    )
    cli.add_argument(
        '--temperature', type=float, default=0.8,
        help='Sampling temperature',
    )
    cli.add_argument('--top_k', type=int, default=50, help='Top-k sampling')
    cli.add_argument(
        '--top_p', type=float, default=0.9,
        help='Nucleus sampling threshold',
    )
    cli.add_argument(
        '--analyze', action='store_true',
        help='Analyze geometric properties instead of generating',
    )
    args = cli.parse_args()

    print("Loading model...")
    # A local checkpoint, when given, wins over the (defaulted) Hub repo.
    repo = None if args.checkpoint else args.repo
    model, tokenizer = load_model(checkpoint_path=args.checkpoint, repo_id=repo)
    print(f"Model loaded on {next(model.parameters()).device}")

    if not args.analyze:
        print(f"\nGenerating from: {args.prompt}")
        text = generate(
            model, tokenizer, args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        rule = "=" * 60
        print("\n" + rule)
        print(text)
        print(rule)
        return

    print(f"\nAnalyzing: {args.prompt}")
    stats = analyze_geometry(model, tokenizer, args.prompt)
    for level, level_stats in stats.items():
        print(f"\n{level}:")
        for name, value in level_stats.items():
            print(f" {name}: {value:.6f}")


if __name__ == '__main__':
    main()