CharlesCNorton committed
Commit 470f0a9 · 1 Parent(s): a6b5f0c

Add SmolLM2-360M architecture analysis, fix PositionExtractor tokenization


ARCHITECTURE ANALYSIS
---------------------
- Add SMOLLM2_ARCHITECTURE.md: comprehensive technical reference (457 lines)
- 361.82M params, hidden_dim=960, 32 transformer layers
- Grouped Query Attention: 15 query heads, 5 KV heads (3:1 ratio)
- SwiGLU MLP: gate/up (960->2560), down (2560->960)
- RoPE position encoding (theta=100k, max 8192 tokens)
- Weight inventory: per-layer breakdown, parameter distribution

- Document critical tokenization behavior:
  - Digits tokenized individually: token_id = 32 + digit_value
  - "47 + 86" -> ['4', '7', ' +', ' ', '8', '6'] (6 tokens, not 8)
  - Operator tokens: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758
  - Space token: 216

- Hidden state analysis: Layer 31 (final) has std=1.34, ideal for extraction
- Add analyze_smollm2.py and smollm2_analysis.json for reproducibility

POSITIONEXTRACTOR FIX (model.py)
--------------------------------
The previous implementation hardcoded position assumptions:
- Assumed 3 tokens for operand A (positions 0-2)
- Assumed 2 tokens for operator (positions 3-4)
- Assumed 3 tokens for operand B (positions 5-7)

This was wrong: "47 + 86" is 6 tokens with A at 0-1, op at 2, space at 3, B at 4-5

Fix implements dynamic token-based detection:
- DIGIT_TOKENS = set(range(32, 42)) for '0'-'9'
- OPERATOR_TOKENS dict maps token IDs to operation indices
- _find_operator_position() scans for known operator tokens
- _extract_digit_features() handles 1-3 digit operands with LEFT-PADDING
  (ensures the units digit is always aligned regardless of number length)
- Now requires token_ids parameter for accurate parsing
- Returns op_indices_from_tokens for potential supervision signal
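A minimal sketch of the dynamic detection logic described above, using the documented token IDs. Helper names are hypothetical: the real PositionExtractor operates on hidden states; these functions only illustrate the position arithmetic:

```python
# Illustrative sketch only (hypothetical helpers, not the project's model.py).
DIGIT_TOKENS = set(range(32, 42))     # token IDs for '0'-'9'
OPERATOR_TOKENS = {1232: 0, 731: 1, 1672: 2, 2986: 3, 2067: 4, 1758: 5}

def find_operator_position(token_ids):
    """Scan for the first known operator token; return (position, op_index)."""
    for pos, t in enumerate(token_ids):
        if t in OPERATOR_TOKENS:
            return pos, OPERATOR_TOKENS[t]
    raise ValueError("no operator token found")

def digit_positions(token_ids, op_pos):
    """Collect digit positions for each operand, left-padded to length 3
    (None = padding) so the units digit always sits in the last slot."""
    a = [p for p in range(op_pos) if token_ids[p] in DIGIT_TOKENS]
    b = [p for p in range(op_pos + 1, len(token_ids)) if token_ids[p] in DIGIT_TOKENS]
    def pad(ps):
        return [None] * (3 - len(ps)) + ps
    return pad(a), pad(b)

# "47 + 86": A at positions 0-1, ' +' at 2, space at 3, B at 4-5
toks = [36, 39, 1232, 216, 40, 38]
op_pos, op_idx = find_operator_position(toks)   # (2, 0) -> add
a_pos, b_pos = digit_positions(toks, op_pos)    # ([None, 0, 1], [None, 4, 5])
```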

ARITHMETICMODEL UPDATES (model.py)
----------------------------------
- get_hidden_states() now returns (hidden, mask, token_ids)
- forward() passes token_ids to PositionExtractor when position_extract=True
- Handles variable return signatures across extractor types:
  - Extractor: (result_bits, a_bits, b_bits, op_logits)
  - PositionExtractor: + op_indices_from_tokens
  - DigitExtractor: + a_digit_logits, b_digit_logits
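The indexed-output convention above can be sketched as follows. The tuples here are mocked stand-ins (the real forward() returns tensors), so only the unpacking pattern is the point:

```python
# Sketch of the shared indexing convention across extractor types.
def unpack(outputs):
    """First four slots are common to every extractor; extras vary by type."""
    result_bits, a_bits, b_bits, op_logits = outputs[:4]
    extras = outputs[4:]   # op_indices_from_tokens, or digit logits, if present
    return result_bits, a_bits, b_bits, op_logits, extras

base = ("result", "a", "b", "op")                 # Extractor
pos  = base + ("op_idx",)                         # PositionExtractor
dig  = base + ("a_digits", "b_digits")            # DigitExtractor

for outputs in (base, pos, dig):
    r, a, b, op, extras = unpack(outputs)         # same first four everywhere
```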

TRAIN.PY UPDATES
----------------
- evaluate_llm() uses indexed outputs for compatibility with all extractors
- Training loop uses outputs[0], outputs[1], outputs[2], outputs[3]
- Sample predictions updated similarly

README.MD UPDATES
-----------------
- Add "Target Model: SmolLM2-360M-Instruct" section with architecture table
- Link to SMOLLM2_ARCHITECTURE.md for full technical reference
- Update Interface Layers section with actual Extractor/MultiHeadBitExtractor code
- Update Trainable Parameters with accurate counts (~4.4M for full Extractor)
- Update Training Strategy with actual loss components and commands
- Update Stage 3 progress with training infrastructure table
- Update Files section: split Core/LLM Integration, add new files
- Add references: SmolLM2 model card, Transformer paper, RoPE paper

VERIFICATION
------------
All operator detection tests pass:
5 + 3 -> A=5, B=3, op=add [OK]
47 + 86 -> A=47, B=86, op=add [OK]
127 - 28 -> A=127, B=28, op=sub [OK]
12 * 11 -> A=12, B=11, op=mul [OK]
200 > 50 -> A=200, B=50, op=gt [OK]
3 < 100 -> A=3, B=100, op=lt [OK]
42 == 42 -> A=42, B=42, op=eq [OK]

README.md CHANGED

@@ -308,6 +308,25 @@ We solve this by embedding **frozen, proven-correct arithmetic circuits** direct
 
 The model learns **call dispatch**, not arithmetic. The arithmetic is already solved.
 
+### Target Model: SmolLM2-360M-Instruct
+
+We use HuggingFace's SmolLM2-360M-Instruct as our base model. See [`llm_integration/SMOLLM2_ARCHITECTURE.md`](llm_integration/SMOLLM2_ARCHITECTURE.md) for the complete technical analysis.
+
+| Property | Value |
+|----------|-------|
+| Parameters | 361.82M |
+| Hidden Dimension | **960** (matches extractor input) |
+| Layers | 32 transformer blocks |
+| Attention | 15 query heads, 5 KV heads (GQA) |
+| MLP | SwiGLU (960→2560→960) |
+| Position Encoding | RoPE (theta=100k, max 8192) |
+
+**Key insight**: The hidden dimension of 960 exactly matches our extractor requirements—no projection layer needed.
+
+**Tokenization**: Digits are tokenized individually (`"47 + 86"` → `['4', '7', ' +', ' ', '8', '6']`), with digit token IDs following `token_id = 32 + digit_value`. This enables position-based operand extraction.
+
+**Hidden State Extraction**: Layer 31 (final, pre-LM-head) provides well-normalized representations (std=1.34) ideal for bit extraction. All 33 hidden state outputs are available (embedding + 32 layers).
+
 ### Architecture
 
 Standard MLP block with parallel circuit path:

@@ -323,7 +342,7 @@ x ──┬── MLP path ────────────────┬
 Augmented MLP forward pass:
 
 ```python
-def forward(x):  # x: [batch, seq, d_model]
+def forward(x):  # x: [batch, seq, d_model=960]
     # Original MLP path (unchanged)
     mlp_out = self.down_proj(silu(self.gate_proj(x)) * self.up_proj(x))
 

@@ -370,56 +389,75 @@ Full adder = 2 half-adders + carry OR, ~4 threshold layers.
 
 ### Interface Layers (Trainable)
 
-**BitExtractor** — Maps embedding → two 8-bit operands:
+**Extractor** — Extracts operands and operation from LLM hidden states:
 
 ```python
-class BitExtractor(nn.Module):
-    def __init__(self, d_model):
-        self.proj = nn.Linear(d_model, 16)
+class Extractor(nn.Module):
+    """Attention pooling + per-bit extraction networks."""
+
+    def __init__(self, hidden_dim=960):
+        self.attention_pool = AttentionPooling(hidden_dim, num_heads=4)
+        self.a_extractor = MultiHeadBitExtractor(hidden_dim)  # 8 separate bit networks
+        self.b_extractor = MultiHeadBitExtractor(hidden_dim)
+        self.op_router = nn.Sequential(
+            nn.Linear(hidden_dim, 256), nn.GELU(),
+            nn.Linear(256, 6)  # 6 operations
+        )
 
-    def forward(self, x):
-        logits = self.proj(x)
-        bits = heaviside(logits)  # STE for training
-        return bits[..., :8], bits[..., 8:]
+    def forward(self, hidden_states, attention_mask):
+        pooled = self.attention_pool(hidden_states, attention_mask)  # (batch, 960)
+        a_bits, _ = self.a_extractor(pooled)   # (batch, 8)
+        b_bits, _ = self.b_extractor(pooled)   # (batch, 8)
+        op_logits = self.op_router(pooled)     # (batch, 6)
+        return a_bits, b_bits, op_logits
 ```
 
-**BitInjector** — Maps result bits → embedding delta:
+**MultiHeadBitExtractor** — 8 specialized networks, one per bit:
 
 ```python
-class BitInjector(nn.Module):
-    def __init__(self, d_model):
-        self.proj = nn.Linear(16, d_model)
-        self.scale = nn.Parameter(torch.tensor(0.1))
-
-    def forward(self, result_bits, flags):
-        combined = torch.cat([result_bits, flags], dim=-1)
-        return self.proj(combined) * self.scale
+class MultiHeadBitExtractor(nn.Module):
+    def __init__(self, hidden_dim=960):
+        self.bit_extractors = nn.ModuleList([
+            nn.Sequential(nn.Linear(hidden_dim, 128), nn.GELU(), nn.Linear(128, 1))
+            for _ in range(8)
+        ])
+
+    def forward(self, x):
+        logits = torch.cat([ext(x) for ext in self.bit_extractors], dim=-1)
+        soft = torch.sigmoid(logits)
+        hard = heaviside_ste(logits)
+        return hard - soft.detach() + soft, logits  # STE
 ```
 
-**Router** — Decides when to use circuits:
+**AttentionPooling** — Learns which token positions matter:
 
 ```python
-class Router(nn.Module):
-    def __init__(self, d_model):
-        self.net = nn.Sequential(
-            nn.Linear(d_model, 64), nn.ReLU(),
-            nn.Linear(64, 2), nn.Softmax(dim=-1)
-        )
+class AttentionPooling(nn.Module):
+    """CLS-token style pooling with learned attention."""
+
+    def __init__(self, hidden_dim=960, num_heads=4):
+        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
+        self.query = nn.Linear(hidden_dim, hidden_dim)
+        self.key = nn.Linear(hidden_dim, hidden_dim)
+        self.value = nn.Linear(hidden_dim, hidden_dim)
```
 
 ### Trainable Parameters
 
-For SmolLM2-360M (d_model=960), augmenting 11 layers:
-
-| Component | Params/Layer |
-|-----------|-------------|
-| BitExtractor | 15,376 |
-| BitInjector | 16,321 |
-| Router | 61,698 |
-| OpSelector | ~31,000 |
-| **Total** | ~124,395 |
-
-**11 layers × 124,395 = ~1.37M trainable parameters** (0.38% of model)
+For SmolLM2-360M (hidden_dim=960):
+
+| Component | Parameters | Description |
+|-----------|------------|-------------|
+| AttentionPooling | ~3.7M | 4-head attention over sequence |
+| MultiHeadBitExtractor (×2) | ~245K each | 8 per-bit MLPs for A and B |
+| OpRouter | ~246K | 960→256→6 MLP |
+| **Extractor Total** | ~4.4M | Full extraction module |
+
+**Alternative architectures**:
+- `PositionExtractor`: ~1.5M (position-specific, no attention)
+- `DigitExtractor`: ~1.2M (predicts digits 0-9 instead of bits)
+
+With `--unfreeze_layers 4`: Adds ~39.3M trainable params (top 4 transformer layers).
 
 ### Gradient Flow
 

@@ -438,20 +476,33 @@ class HeavisideSTE(torch.autograd.Function):
 
 ### Training Strategy
 
-1. **Data**: Generate 8-bit arithmetic problems exhaustively (256×256 = 65,536 unique)
-2. **Loss**: Cross-entropy on answer tokens only (prompt masked with -100)
-3. **Optimizer**: AdamW on interface params only, lr=1e-4
-4. **Curriculum**: Single-digit → two-digit → full 8-bit → adversarial (127+128, 255+1)
+1. **Data**: Random 8-bit arithmetic problems (operands 0-255, 6 operations)
+2. **Loss**: Multi-component BCE + CE
+   - `result_loss`: BCE on output bits vs expected
+   - `a_loss`, `b_loss`: BCE on extracted bits vs ground truth (2× weight)
+   - `op_loss`: CE on operation classification
+3. **Optimizer**: AdamW, lr=3e-4, gradient clipping at 1.0
+4. **Curriculum**: Epoch-based range expansion (0-9 → 0-99 → 0-255)
+5. **Batching**: 256-4096 samples per batch (VRAM-dependent)
+
+```bash
+# Example training commands
+python train.py --mode router --epochs 100                        # Sanity check
+python train.py --mode llm --epochs 100 --batch_size 256          # Frozen LLM
+python train.py --mode llm --unfreeze_layers 4 --batch_size 4096  # Fine-tune top layers
+```
 
 ### Inference
 
-At inference, Heaviside is true step function—no approximation. If BitExtractor correctly extracts operands, the circuit **will** output the correct result. Circuit computation adds ~5-10% latency overhead.
+At inference, Heaviside is true step function—no approximation. If the Extractor correctly identifies operands, the circuit **will** output the correct result.
 
 ### Target Performance
 
-| Model | Baseline | Target |
-|-------|----------|--------|
-| SmolLM2-360M | ~5-10% | >95% |
+| Condition | Configuration | Accuracy |
+|-----------|---------------|----------|
+| Control | Vanilla SmolLM2-360M | 11.90% |
+| Circuits only | Ground truth bits | 100.00% |
+| Experimental | LLM + Extractor + Circuits | **Target: 100%** |
 
 The interface generalizes to **all** 65,536 8-bit additions once trained—no memorization, the circuits compute.
 

@@ -535,19 +586,37 @@ Head-to-head on 50 random cases: SmolLM2 got 7/50 (14%), circuits got 50/50 (100
 The actual challenge: train an interface that extracts operands and operations from LLM hidden states (not from pre-formatted bit inputs).
 
 ```
-"What is 47 + 86?"
+"47 + 86"
    ↓
-[LLM hidden states]
+[SmolLM2 hidden states: (seq_len, 960)]
    ↓
-BitExtractor (must LEARN: "47" → 00101111, "86" → 01010110)
-OpRouter (must LEARN: "+" → add operation)
+Extractor (must LEARN: hidden → a_bits, b_bits, op_logits)
    ↓
 [Frozen threshold circuits]
    ↓
-[Result bits] → "133"
+[Result bits] → 133
```
 
-The `train_passthrough_*.py` files demonstrate that routing works when given labels, but this is trivial—the real test is learning to parse from natural language.
+**Training Infrastructure** (`train.py`):
+
+| Mode | Description | Status |
+|------|-------------|--------|
+| `--mode router` | Train OpRouter with ground truth bits | 100% achieved |
+| `--mode interface` | Train BitEncoder + OpRouter | Ready |
+| `--mode llm` | Train from LLM hidden states | Active development |
+
+**LLM Mode Options**:
+- `--unfreeze_layers N`: Fine-tune top N transformer layers
+- `--extract_layer N`: Extract from intermediate layer (-1 = final)
+- `--position_extract`: Position-specific extraction (uses token positions)
+- `--digit_pred`: Predict digits (0-9) instead of bits
+
+**Extraction Architectures** (`model.py`):
+- `Extractor`: Attention pooling + per-bit MLPs
+- `PositionExtractor`: Position-aware (operand A from positions 0-2, B from 5-7)
+- `DigitExtractor`: Predicts 3 digits per operand, converts to bits
+
+**Curriculum Learning**: Training progresses 0-9 → 0-99 → 0-255 over epochs.
 
 #### Proof of Concept Scope
 

@@ -555,7 +624,11 @@ The `train_passthrough_*.py` files demonstrate that routing works when given lab
 - **Six operations**: ADD, SUB, MUL, GT, LT, EQ
 - **Pure ALU profile** (no memory access)
 
-**Current status**: Circuit validation complete. LLM hidden state extraction in development.
+**Current Status**:
+- Circuit validation: Complete (100% on all operations)
+- LLM baseline: Measured (11.90%)
+- SmolLM2 architecture analysis: Complete (see `SMOLLM2_ARCHITECTURE.md`)
+- Extraction training: In progress
 
 ### Extension Roadmap
 

@@ -581,18 +654,26 @@ The following extensions are planned after proof-of-concept validation:
 
 ## Files
 
+### Core
+
 | File | Description |
 |------|-------------|
-| `neural_computer.safetensors` | 15,685 tensors, 43,366 parameters (pure ALU profile) |
-| `eval.py` | Unified evaluation suite (6,738 tests, GPU-batched) |
-| `build.py` | Build tools with configurable memory partitioning |
+| `neural_computer.safetensors` | Frozen threshold circuits (~8.29M params full, ~32K pure ALU) |
+| `eval.py` | Unified evaluation suite (GPU-batched, exhaustive testing) |
+| `build.py` | Circuit generator with configurable memory profiles |
 | `prune_weights.py` | Weight magnitude pruning (GPU-batched, binary search conflict resolution) |
-| `llm_integration/baseline.py` | SmolLM2-360M arithmetic baseline evaluation (11.90% fitness) |
-| `llm_integration/fitness.py` | Shared fitness function for randomized arithmetic tests |
-| `llm_integration/circuits.py` | Frozen threshold circuit wrapper with STE gradients |
-| `llm_integration/model.py` | Interface layer definitions (BitEncoder, OpRouter, BitDecoder) |
-| `llm_integration/train_passthrough.py` | Scaffolding: trains with pre-formatted bit inputs |
-| `llm_integration/train_passthrough_router.py` | Scaffolding: router-only with ground truth bits |
+
+### LLM Integration (`llm_integration/`)
+
+| File | Description |
+|------|-------------|
+| `SMOLLM2_ARCHITECTURE.md` | Complete technical analysis of SmolLM2-360M (layers, weights, tokenization) |
+| `baseline.py` | SmolLM2-360M vanilla arithmetic evaluation (11.90% baseline) |
+| `circuits.py` | Frozen threshold circuit wrapper with STE gradients |
+| `fitness.py` | Shared fitness function (randomized arithmetic, no answer supervision) |
+| `model.py` | Interface layers: `BitEncoder`, `OpRouter`, `Extractor`, `ArithmeticModel` |
+| `train.py` | Unified training: `--mode router`, `--mode interface`, `--mode llm` |
+| `trained/router.pt` | Trained OpRouter checkpoint (100% with ground truth bits) |
 
 ### Build Tool Usage
 

@@ -653,4 +734,6 @@ MIT
 3. Siegelmann & Sontag (1995). "On the Computational Power of Neural Nets"
 4. Bengio et al. (2013). "Estimating or Propagating Gradients Through Stochastic Neurons"
 5. Ma et al. (2024). "The Era of 1-bit LLMs" (BitNet b1.58)
-6. HuggingFace (2024). "SmolLM2: Small Language Models"
+6. HuggingFace (2024). "SmolLM2: Small Language Models" — [Model Card](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct)
+7. Vaswani et al. (2017). "Attention Is All You Need" — Transformer architecture
+8. Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding" — RoPE
llm_integration/SMOLLM2_ARCHITECTURE.md ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SmolLM2-360M-Instruct Architecture Analysis
2
+
3
+ Technical reference document for the 8bit-threshold-computer LLM integration project.
4
+
5
+ **Model**: `HuggingFaceTB/SmolLM2-360M-Instruct`
6
+ **Architecture**: LlamaForCausalLM (Llama 2 variant)
7
+ **Tokenizer**: GPT2TokenizerFast
8
+ **Analysis Date**: 2026-01-21
9
+
10
+ ---
11
+
12
+ ## 1. Executive Summary
13
+
14
+ SmolLM2-360M-Instruct is a 362M parameter causal language model using the Llama architecture. Key characteristics relevant to our bit extraction task:
15
+
16
+ - **Hidden dimension: 960** (matches our extractor input requirement)
17
+ - **32 transformer layers** providing multiple extraction points
18
+ - **Digit-level tokenization** for numbers (each digit is a separate token)
19
+ - **Grouped Query Attention (GQA)** with 15 query heads and 5 KV heads
20
+
21
+ ---
22
+
23
+ ## 2. Architecture Census
24
+
25
+ ### 2.1 Core Parameters
26
+
27
+ | Parameter | Value |
28
+ |-----------|-------|
29
+ | Total Parameters | 361,821,120 (361.82M) |
30
+ | Vocabulary Size | 49,152 |
31
+ | Hidden Dimension | 960 |
32
+ | Intermediate Dimension (MLP) | 2,560 |
33
+ | Number of Layers | 32 |
34
+ | Number of Attention Heads | 15 |
35
+ | Number of KV Heads | 5 (Grouped Query Attention) |
36
+ | Head Dimension | 64 |
37
+ | Max Sequence Length | 8,192 |
38
+ | Activation Function | SiLU |
39
+ | Normalization | RMSNorm (eps=1e-05) |
40
+ | Position Encoding | RoPE (theta=100,000) |
41
+ | Word Embedding Tying | Yes (embed_tokens = lm_head) |
42
+
43
+ ### 2.2 Architecture Diagram
44
+
45
+ ```
46
+ Input Token IDs
47
+ |
48
+ v
49
+ +------------------+
50
+ | Embedding Layer | (49152, 960)
51
+ +------------------+
52
+ |
53
+ v
54
+ +------------------+
55
+ | LlamaDecoderLayer| x 32
56
+ | +-------------+ |
57
+ | | RMSNorm | |
58
+ | +-------------+ |
59
+ | | Self-Attn | | Q: (960, 960), K: (960, 320), V: (960, 320), O: (960, 960)
60
+ | +-------------+ |
61
+ | | Residual | |
62
+ | +-------------+ |
63
+ | | RMSNorm | |
64
+ | +-------------+ |
65
+ | | MLP (SwiGLU)| | gate: (960, 2560), up: (960, 2560), down: (2560, 960)
66
+ | +-------------+ |
67
+ | | Residual | |
68
+ +------------------+
69
+ |
70
+ v
71
+ +------------------+
72
+ | Final RMSNorm | (960,)
73
+ +------------------+
74
+ |
75
+ v
76
+ +------------------+
77
+ | LM Head | (960, 49152) - tied with embeddings
78
+ +------------------+
79
+ |
80
+ v
81
+ Logits (batch, seq, 49152)
82
+ ```
83
+
84
+ ### 2.3 Parameter Distribution
85
+
86
+ | Component | Parameters | Percentage |
87
+ |-----------|-----------|------------|
88
+ | Embedding | 47,185,920 | 13.04% |
89
+ | All Attention Layers | 78,643,200 | 21.74% |
90
+ | All MLP Layers | 235,929,600 | 65.19% |
91
+ | All Layer Norms | 61,440 | 0.02% |
92
+ | Final Norm | 960 | 0.00% |
93
+
94
+ Per-layer breakdown (each of 32 layers):
95
+ - Attention: 2,457,600 params (0.68%)
96
+ - MLP: 7,372,800 params (2.04%)
97
+ - Norms: 1,920 params (0.00%)
98
+
99
+ ---
100
+
101
+ ## 3. Weight Inventory
102
+
103
+ ### 3.1 Embedding and Output Layers
104
+
105
+ | Parameter Name | Shape | Elements | Notes |
106
+ |---------------|-------|----------|-------|
107
+ | `model.embed_tokens.weight` | (49152, 960) | 47,185,920 | Token embeddings |
108
+ | `model.norm.weight` | (960,) | 960 | Final layer norm |
109
+ | `lm_head.weight` | (49152, 960) | (tied) | Tied to embed_tokens |
110
+
111
+ ### 3.2 Single Transformer Layer Structure
112
+
113
+ Each of the 32 layers (`model.layers.{0-31}`) contains:
114
+
115
+ **Attention Block:**
116
+ | Parameter | Shape | Elements |
117
+ |-----------|-------|----------|
118
+ | `self_attn.q_proj.weight` | (960, 960) | 921,600 |
119
+ | `self_attn.k_proj.weight` | (320, 960) | 307,200 |
120
+ | `self_attn.v_proj.weight` | (320, 960) | 307,200 |
121
+ | `self_attn.o_proj.weight` | (960, 960) | 921,600 |
122
+ | **Attention Total** | | **2,457,600** |
123
+
124
+ **MLP Block (SwiGLU):**
125
+ | Parameter | Shape | Elements |
126
+ |-----------|-------|----------|
127
+ | `mlp.gate_proj.weight` | (2560, 960) | 2,457,600 |
128
+ | `mlp.up_proj.weight` | (2560, 960) | 2,457,600 |
129
+ | `mlp.down_proj.weight` | (960, 2560) | 2,457,600 |
130
+ | **MLP Total** | | **7,372,800** |
131
+
132
+ **Normalization:**
133
+ | Parameter | Shape | Elements |
134
+ |-----------|-------|----------|
135
+ | `input_layernorm.weight` | (960,) | 960 |
136
+ | `post_attention_layernorm.weight` | (960,) | 960 |
137
+ | **Norms Total** | | **1,920** |
138
+
139
+ **Layer Total: 9,832,320 parameters**
140
+
141
+ ### 3.3 Grouped Query Attention (GQA) Details
142
+
143
+ SmolLM2 uses GQA with a 3:1 ratio:
144
+ - 15 query heads (Q dimension: 960 = 15 x 64)
145
+ - 5 key-value heads (KV dimension: 320 = 5 x 64)
146
+ - Each KV head is shared by 3 query heads
147
+ - This reduces KV cache memory by ~67% vs standard MHA
148
+
149
+ ---
150
+
151
+ ## 4. Tokenization Analysis
152
+
153
+ ### 4.1 Arithmetic Expression Tokenization
154
+
155
+ Test input: `"47 + 86"`
156
+
157
+ | Position | Token ID | Token | Description |
158
+ |----------|----------|-------|-------------|
159
+ | 0 | 36 | `'4'` | First digit of operand A |
160
+ | 1 | 39 | `'7'` | Second digit of operand A |
161
+ | 2 | 1232 | `' +'` | Space + plus sign |
162
+ | 3 | 216 | `' '` | Trailing space |
163
+ | 4 | 40 | `'8'` | First digit of operand B |
164
+ | 5 | 38 | `'6'` | Second digit of operand B |
165
+
166
+ ### 4.2 Digit Token Mappings
167
+
168
+ | Digit | Token ID |
169
+ |-------|----------|
170
+ | 0 | 32 |
171
+ | 1 | 33 |
172
+ | 2 | 34 |
173
+ | 3 | 35 |
174
+ | 4 | 36 |
175
+ | 5 | 37 |
176
+ | 6 | 38 |
177
+ | 7 | 39 |
178
+ | 8 | 40 |
179
+ | 9 | 41 |
180
+
181
+ Key observations:
182
+ - **Digits are tokenized individually** (no multi-digit tokens like "47")
183
+ - Digit tokens are sequential: ID = 32 + digit_value
184
+ - Space-prefixed operators exist (e.g., `' +'` = 1232)
185
+ - `'='` has token ID 45
186
+
187
+ ### 4.3 Implications for Bit Extraction
188
+
189
+ The digit-by-digit tokenization means:
190
+ 1. For `"47 + 86"`, operand A spans positions [0,1] and operand B spans positions [4,5]
191
+ 2. The model must learn to:
192
+ - Recognize digit boundaries
193
+ - Compose multi-digit numbers from sequential tokens
194
+ - Perform arithmetic across token positions
195
+ 3. Hidden states at digit positions contain the numerical representation
196
+
197
+ ---
198
+
199
+ ## 5. Hidden State Analysis
200
+
201
+ ### 5.1 Hidden State Output Structure
202
+
203
+ When running with `output_hidden_states=True`:
204
+ - Returns **33 hidden states** (embedding + 32 layer outputs)
205
+ - Each has shape: `(batch_size, seq_len, hidden_dim)`
206
+ - For `"47 + 86"`: `(1, 6, 960)`
207
+
208
+ ### 5.2 Hidden State Statistics by Layer
209
+
210
+ | Layer | Mean | Std Dev | Min | Max |
211
+ |-------|------|---------|-----|-----|
212
+ | Embedding | -0.001 | 0.105 | -0.44 | 1.77 |
213
+ | Layer 0 | -0.127 | 2.55 | -80.8 | 19.0 |
214
+ | Layer 1 | -0.171 | 3.70 | -161 | 39.7 |
215
+ | Layer 2 | -0.151 | 3.67 | -102 | 61.4 |
216
+ | Layer 3 | -1.13 | 327 | -21,722 | 11,856 |
217
+ | Layer 4-12 | ~-1.3 | ~327 | ~-21,700 | ~11,800 |
218
+ | Layer 13-26 | ~-1.5 | ~337 | ~-22,400 | ~12,100 |
219
+ | Layer 27-30 | ~-1.8 | ~310 | ~-20,000 | ~11,800 |
220
+ | Layer 31 | 0.017 | 1.34 | -18.9 | 34.3 |
221
+
222
+ Key observations:
223
+ 1. **Dramatic variance explosion at Layer 3**: Std dev jumps from ~4 to ~327
224
+ 2. **Stable middle layers (4-26)**: Consistent statistics, suggesting numerical computation
225
+ 3. **Compression at final layer**: Std dev drops to 1.34 at Layer 31 (pre-softmax normalization)
226
+ 4. **Layer 31 is well-scaled** for downstream processing
227
+
228
+ ### 5.3 Extraction Point Candidates
229
+
230
+ | Layer Range | Characteristics | Suitability |
231
+ |-------------|-----------------|-------------|
232
+ | 0-2 (Early) | Low variance, close to embeddings | Poor - minimal computation |
233
+ | 3-12 (Early-Mid) | High variance, initial processing | Moderate - may contain raw numerical features |
234
+ | 13-26 (Middle) | Stable high variance | Good - computation in progress |
235
+ | 27-30 (Late) | Variance compression begins | Good - refined representations |
236
+ | 31 (Final) | Well-normalized output | Best - final representation before LM head |
237
+
238
+ ---
239
+
240
+ ## 6. Relevance to 8bit-Threshold-Computer Project
241
+
242
+ ### 6.1 Hidden Dimension Match
243
+
244
+ **The hidden dimension of 960 exactly matches our extractor input requirement.** This is fortuitous as it means:
245
+ - No projection layer needed to interface with our bit extractor
246
+ - Direct extraction from any layer's hidden states
247
+ - Full utilization of the model's representational capacity
248
+
249
+ ### 6.2 Recommended Extraction Strategy
250
+
251
+ ```python
252
+ def extract_hidden_state(model, tokenizer, expression, layer=-1):
253
+ """
254
+ Extract hidden state for bit extraction.
255
+
256
+ Args:
257
+ layer: Which layer to extract from (default: final layer)
258
+ -1 = Layer 31 (final, pre-LM-head)
259
+
260
+ Returns:
261
+ Tensor of shape (960,) for the last token position
262
+ """
263
+ inputs = tokenizer(expression, return_tensors="pt")
264
+ outputs = model(**inputs, output_hidden_states=True)
265
+
266
+ # hidden_states[0] = embedding, hidden_states[1] = layer 0, ...
267
+ # hidden_states[32] = layer 31 (final)
268
+ hidden = outputs.hidden_states[layer] # (1, seq_len, 960)
269
+
270
+ # Extract last token position for autoregressive prediction
271
+ return hidden[0, -1, :] # (960,)
272
+ ```
273
+
274
+ ### 6.3 Token Position Analysis
275
+
276
+ For arithmetic expressions like `"A + B"`:
277
+
278
+ ```
279
+ Tokens: [d1] [d2] [ +] [ ] [d3] [d4]
280
+ Positions: 0 1 2 3 4 5
281
+
282
+ Operand A: positions 0 to (plus_pos - 1)
283
+ Operator: position where ' +' token appears
284
+ Operand B: positions (plus_pos + 2) to end
285
+ ```
286
+
287
+ Strategy for operand extraction:
288
+ 1. Find the `' +'` token (ID 1232) position
289
+ 2. Collect hidden states at digit positions before it (operand A)
290
+ 3. Collect hidden states at digit positions after it (operand B)
291
+ 4. Consider pooling (mean, max) or concatenating digit hidden states
292
+
293
+ ### 6.4 Attention Pattern Utilization
294
+
295
+ With GQA (15 query heads, 5 KV heads), we can analyze attention patterns to:
296
+ 1. Identify which positions attend to operand digits
297
+ 2. Determine if the model explicitly links corresponding digit positions
298
+ 3. Find heads that specialize in numerical reasoning
299
+
300
+ ```python
301
+ def get_attention_weights(model, tokenizer, expression):
302
+ inputs = tokenizer(expression, return_tensors="pt")
303
+ outputs = model(**inputs, output_attentions=True)
304
+ # attentions: tuple of (batch, num_heads, seq_len, seq_len) per layer
305
+ return outputs.attentions
306
+ ```
307
+
308
+ ### 6.5 Extraction Interface Specification
309
+
310
+ For integration with the threshold computer:
311
+
312
+ ```python
313
+ class SmolLM2Extractor:
314
+ """Interface between SmolLM2 and threshold-based bit extraction."""
315
+
316
+ def __init__(self, model, tokenizer, extraction_layer=31):
317
+ self.model = model
318
+ self.tokenizer = tokenizer
319
+ self.layer = extraction_layer + 1 # +1 because index 0 is embedding
320
+
321
+ def get_hidden_state(self, text: str) -> torch.Tensor:
322
+ """
323
+ Returns: Tensor of shape (960,) ready for bit extractor
324
+ """
325
+ tokens = self.tokenizer(text, return_tensors="pt")
326
+ with torch.no_grad():
327
+ outputs = self.model(**tokens, output_hidden_states=True)
328
+ return outputs.hidden_states[self.layer][0, -1, :]
329
+
330
+ def get_all_position_states(self, text: str) -> torch.Tensor:
331
+ """
332
+ Returns: Tensor of shape (seq_len, 960) for all positions
333
+ """
334
+ tokens = self.tokenizer(text, return_tensors="pt")
335
+ with torch.no_grad():
336
+ outputs = self.model(**tokens, output_hidden_states=True)
337
+ return outputs.hidden_states[self.layer][0]
338
+ ```
339
+
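One detail worth flagging in the class above: `hidden_states` has 33 entries (the embedding output plus one per transformer layer), so the `+1` offset maps `extraction_layer=31` to the final entry:

```python
# hidden_states layout: index 0 = embedding output, indices 1..32 = layers 0..31.
NUM_LAYERS = 32
extraction_layer = 31            # final transformer layer
index = extraction_layer + 1     # offset past the embedding entry
assert index == NUM_LAYERS       # hidden_states[32], the last of 33 entries
print(index)  # 32
```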
340
+ ---
341
+
342
+ ## 7. Complete Weight Inventory Table
343
+
344
+ ### 7.1 All Named Parameters
345
+
346
+ ```
347
+ EMBEDDING (47,185,920 params - 13.04%)
348
+ model.embed_tokens.weight (49152, 960) 47,185,920
349
+
350
+ LAYER 0 (9,832,320 params - 2.72%)
351
+ Attention (2,457,600):
352
+ model.layers.0.self_attn.q_proj.weight (960, 960) 921,600
353
+ model.layers.0.self_attn.k_proj.weight (320, 960) 307,200
354
+ model.layers.0.self_attn.v_proj.weight (320, 960) 307,200
355
+ model.layers.0.self_attn.o_proj.weight (960, 960) 921,600
356
+ MLP (7,372,800):
357
+ model.layers.0.mlp.gate_proj.weight (2560, 960) 2,457,600
358
+ model.layers.0.mlp.up_proj.weight (2560, 960) 2,457,600
359
+ model.layers.0.mlp.down_proj.weight (960, 2560) 2,457,600
360
+ Norms (1,920):
361
+ model.layers.0.input_layernorm.weight (960,) 960
362
+ model.layers.0.post_attention_layernorm.weight (960,) 960
363
+
364
+ [Layers 1-31 follow identical structure, each with 9,832,320 params]
365
+
366
+ FINAL NORM (960 params - 0.00%)
367
+ model.norm.weight (960,) 960
368
+
369
+ LM HEAD (tied with embed_tokens)
370
+ lm_head.weight (49152, 960) [shared]
371
+ ```
372
+
373
+ ### 7.2 Summary by Component Type
374
+
375
+ | Component Type | Count | Params Each | Total Params |
376
+ |----------------|-------|-------------|--------------|
377
+ | Embedding | 1 | 47,185,920 | 47,185,920 |
378
+ | Q Projection | 32 | 921,600 | 29,491,200 |
379
+ | K Projection | 32 | 307,200 | 9,830,400 |
380
+ | V Projection | 32 | 307,200 | 9,830,400 |
381
+ | O Projection | 32 | 921,600 | 29,491,200 |
382
+ | Gate Projection | 32 | 2,457,600 | 78,643,200 |
383
+ | Up Projection | 32 | 2,457,600 | 78,643,200 |
384
+ | Down Projection | 32 | 2,457,600 | 78,643,200 |
385
+ | Input LayerNorm | 32 | 960 | 30,720 |
386
+ | Post-Attn LayerNorm | 32 | 960 | 30,720 |
387
+ | Final LayerNorm | 1 | 960 | 960 |
388
+ | **Total** | | | **361,821,120** |
389
+
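The totals in the table can be cross-checked arithmetically from the architecture parameters (hidden=960, intermediate=2560, vocab=49152, 32 layers, KV dim 5 x 64 = 320):

```python
hidden, inter, vocab, layers = 960, 2560, 49152, 32
kv_dim = 5 * 64  # 5 KV heads x head_dim 64 = 320

embedding = vocab * hidden                        # 47,185,920 (tied with lm_head)
attn = 2 * hidden * hidden + 2 * kv_dim * hidden  # q/o: 921,600 each; k/v: 307,200 each
mlp = 3 * hidden * inter                          # gate/up/down: 2,457,600 each
norms = 2 * hidden                                # input + post-attention RMSNorm
per_layer = attn + mlp + norms                    # 9,832,320

total = embedding + layers * per_layer + hidden   # + final norm
print(total)  # 361821120
```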
390
+ ---
391
+
392
+ ## 8. Configuration Reference
393
+
394
+ Complete model configuration from HuggingFace:
395
+
396
+ ```python
397
+ {
398
+ "architectures": ["LlamaForCausalLM"],
399
+ "attention_bias": False,
400
+ "attention_dropout": 0.0,
401
+ "bos_token_id": 1,
402
+ "eos_token_id": 2,
403
+ "pad_token_id": 2,
404
+ "head_dim": 64,
405
+ "hidden_act": "silu",
406
+ "hidden_size": 960,
407
+ "initializer_range": 0.02,
408
+ "intermediate_size": 2560,
409
+ "max_position_embeddings": 8192,
410
+ "mlp_bias": False,
411
+ "model_type": "llama",
412
+ "num_attention_heads": 15,
413
+ "num_hidden_layers": 32,
414
+ "num_key_value_heads": 5,
415
+ "pretraining_tp": 1,
416
+ "rms_norm_eps": 1e-05,
417
+ "rope_interleaved": False,
418
+ "rope_theta": 100000,
419
+ "tie_word_embeddings": True,
420
+ "use_cache": True,
421
+ "vocab_size": 49152
422
+ }
423
+ ```
424
+
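A few relationships among these fields, stated as a quick consistency sketch:

```python
cfg = {"hidden_size": 960, "num_attention_heads": 15,
       "num_key_value_heads": 5, "head_dim": 64}

assert cfg["hidden_size"] == cfg["num_attention_heads"] * cfg["head_dim"]  # 15 * 64 = 960
assert cfg["num_attention_heads"] % cfg["num_key_value_heads"] == 0        # GQA ratio 3:1
kv_dim = cfg["num_key_value_heads"] * cfg["head_dim"]
print(kv_dim)  # 320, the k_proj/v_proj output dimension
```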
425
+ ---
426
+
427
+ ## 9. Appendix: PyTorch Model Structure
428
+
429
+ ```
430
+ LlamaForCausalLM(
431
+ (model): LlamaModel(
432
+ (embed_tokens): Embedding(49152, 960, padding_idx=2)
433
+ (layers): ModuleList(
434
+ (0-31): 32 x LlamaDecoderLayer(
435
+ (self_attn): LlamaAttention(
436
+ (q_proj): Linear(in_features=960, out_features=960, bias=False)
437
+ (k_proj): Linear(in_features=960, out_features=320, bias=False)
438
+ (v_proj): Linear(in_features=960, out_features=320, bias=False)
439
+ (o_proj): Linear(in_features=960, out_features=960, bias=False)
440
+ )
441
+ (mlp): LlamaMLP(
442
+ (gate_proj): Linear(in_features=960, out_features=2560, bias=False)
443
+ (up_proj): Linear(in_features=960, out_features=2560, bias=False)
444
+ (down_proj): Linear(in_features=2560, out_features=960, bias=False)
445
+ (act_fn): SiLUActivation()
446
+ )
447
+ (input_layernorm): LlamaRMSNorm((960,), eps=1e-05)
448
+ (post_attention_layernorm): LlamaRMSNorm((960,), eps=1e-05)
449
+ )
450
+ )
451
+ (norm): LlamaRMSNorm((960,), eps=1e-05)
452
+ (rotary_emb): LlamaRotaryEmbedding()
453
+ )
454
+ (lm_head): Linear(in_features=960, out_features=49152, bias=False)
455
+ )
456
+ ```
llm_integration/analyze_smollm2.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ SmolLM2-360M-Instruct Architecture Analysis
3
+ For 8bit-threshold-computer LLM Integration Project
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
8
+ from collections import defaultdict
9
+ import json
10
+
11
+ def analyze_smollm2():
12
+ model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"
13
+
14
+ print("=" * 80)
15
+ print("SmolLM2-360M-Instruct Architecture Analysis")
16
+ print("=" * 80)
17
+
18
+ # Load config first
19
+ print("\n[1] Loading model configuration...")
20
+ config = AutoConfig.from_pretrained(model_name)
21
+ print(f"Config loaded: {type(config).__name__}")
22
+
23
+ # Load tokenizer
24
+ print("\n[2] Loading tokenizer...")
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ print(f"Tokenizer loaded: {type(tokenizer).__name__}")
27
+
28
+ # Load model with hidden states output
29
+ print("\n[3] Loading model with output_hidden_states=True...")
30
+ model = AutoModelForCausalLM.from_pretrained(
31
+ model_name,
32
+ output_hidden_states=True,
33
+ torch_dtype=torch.float32
34
+ )
35
+ model.eval()
36
+ print(f"Model loaded: {type(model).__name__}")
37
+
38
+ # ========================================================================
39
+ # ARCHITECTURE CENSUS
40
+ # ========================================================================
41
+ print("\n" + "=" * 80)
42
+ print("ARCHITECTURE CENSUS")
43
+ print("=" * 80)
44
+
45
+ print("\n--- Model Configuration ---")
46
+ config_dict = config.to_dict()
47
+ for key, value in sorted(config_dict.items()):
48
+ print(f" {key}: {value}")
49
+
50
+ print("\n--- Key Architecture Parameters ---")
51
+ print(f" Model type: {config.model_type}")
52
+ print(f" Vocabulary size: {config.vocab_size}")
53
+ print(f" Hidden size: {config.hidden_size}")
54
+ print(f" Intermediate size: {config.intermediate_size}")
55
+ print(f" Number of hidden layers: {config.num_hidden_layers}")
56
+ print(f" Number of attention heads: {config.num_attention_heads}")
57
+ print(f" Number of KV heads: {getattr(config, 'num_key_value_heads', config.num_attention_heads)}")
58
+ print(f" Head dimension: {config.hidden_size // config.num_attention_heads}")
59
+ print(f" Max position embeddings: {config.max_position_embeddings}")
60
+ print(f" RMS norm epsilon: {getattr(config, 'rms_norm_eps', 'N/A')}")
61
+ print(f" Rope theta: {getattr(config, 'rope_theta', 'N/A')}")
62
+ print(f" Tie word embeddings: {getattr(config, 'tie_word_embeddings', 'N/A')}")
63
+
64
+ # ========================================================================
65
+ # WEIGHT INVENTORY
66
+ # ========================================================================
67
+ print("\n" + "=" * 80)
68
+ print("WEIGHT INVENTORY")
69
+ print("=" * 80)
70
+
71
+ total_params = 0
72
+ param_groups = defaultdict(list)
73
+
74
+ for name, param in model.named_parameters():
75
+ total_params += param.numel()
76
+
77
+ # Group by component
78
+ if "embed_tokens" in name:
79
+ group = "Embedding"
80
+ elif "lm_head" in name:
81
+ group = "LM Head"
82
+ elif "norm" in name and "layers" not in name:
83
+ group = "Final Norm"
84
+ elif "layers" in name:
85
+ layer_num = name.split(".")[2]
86
+ if "self_attn" in name:
87
+ group = f"Layer {layer_num} - Attention"
88
+ elif "mlp" in name:
89
+ group = f"Layer {layer_num} - MLP"
90
+ elif "norm" in name:
91
+ group = f"Layer {layer_num} - Norms"
92
+ else:
93
+ group = f"Layer {layer_num} - Other"
94
+ else:
95
+ group = "Other"
96
+
97
+ param_groups[group].append({
98
+ "name": name,
99
+ "shape": tuple(param.shape),
100
+ "numel": param.numel(),
101
+ "dtype": str(param.dtype)
102
+ })
103
+
104
+ print(f"\n--- Total Parameters: {total_params:,} ---")
105
+ print(f" ({total_params / 1e6:.2f}M parameters)")
106
+
107
+ # Print by group
108
+ for group_name in sorted(param_groups.keys(), key=lambda g: (int(g.split()[1]), g) if g.startswith("Layer") else (-1, g)):
109
+ params = param_groups[group_name]
110
+ group_total = sum(p["numel"] for p in params)
111
+ print(f"\n### {group_name} ({group_total:,} params, {group_total/total_params*100:.2f}%)")
112
+ for p in params:
113
+ print(f" {p['name']}")
114
+ print(f" Shape: {p['shape']}, Elements: {p['numel']:,}, Dtype: {p['dtype']}")
115
+
116
+ # ========================================================================
117
+ # TOKENIZATION ANALYSIS
118
+ # ========================================================================
119
+ print("\n" + "=" * 80)
120
+ print("TOKENIZATION ANALYSIS")
121
+ print("=" * 80)
122
+
123
+ test_input = "47 + 86"
124
+ print(f"\n--- Test Input: '{test_input}' ---")
125
+
126
+ tokens = tokenizer(test_input, return_tensors="pt")
127
+ input_ids = tokens["input_ids"][0]
128
+
129
+ print(f"\nInput IDs: {input_ids.tolist()}")
130
+ print(f"Number of tokens: {len(input_ids)}")
131
+
132
+ print("\nToken breakdown:")
133
+ for i, token_id in enumerate(input_ids):
134
+ token_str = tokenizer.decode([token_id])
135
+ print(f" Position {i}: ID={token_id.item():5d}, Token='{token_str}'")
136
+
137
+ # Additional tokenization tests
138
+ print("\n--- Additional Tokenization Tests ---")
139
+ test_cases = ["0", "1", "47", "86", "133", " + ", "="]
140
+ for tc in test_cases:
141
+ ids = tokenizer.encode(tc, add_special_tokens=False)
142
+ decoded = [tokenizer.decode([i]) for i in ids]
143
+ print(f" '{tc}' -> IDs: {ids}, Tokens: {decoded}")
144
+
145
+ # ========================================================================
146
+ # HIDDEN STATE ANALYSIS
147
+ # ========================================================================
148
+ print("\n" + "=" * 80)
149
+ print("HIDDEN STATE ANALYSIS")
150
+ print("=" * 80)
151
+
152
+ print(f"\n--- Running inference on '{test_input}' ---")
153
+
154
+ with torch.no_grad():
155
+ outputs = model(**tokens)
156
+
157
+ hidden_states = outputs.hidden_states
158
+ print(f"\nNumber of hidden state outputs: {len(hidden_states)}")
159
+ print("(This includes embedding output + each layer's output)")
160
+
161
+ print("\nHidden state shapes at each layer:")
162
+ for i, hs in enumerate(hidden_states):
163
+ layer_name = "Embedding" if i == 0 else f"Layer {i-1}"
164
+ print(f" {layer_name}: {tuple(hs.shape)}")
165
+ if i == 0:
166
+ print(f" (batch_size=1, seq_len={hs.shape[1]}, hidden_dim={hs.shape[2]})")
167
+
168
+ # Analyze hidden state statistics at different layers
169
+ print("\n--- Hidden State Statistics (per layer) ---")
170
+ for i, hs in enumerate(hidden_states):
171
+ layer_name = "Embedding" if i == 0 else f"Layer {i-1}"
172
+ hs_flat = hs.view(-1)
173
+ print(f" {layer_name}:")
174
+ print(f" Mean: {hs_flat.mean().item():.6f}")
175
+ print(f" Std: {hs_flat.std().item():.6f}")
176
+ print(f" Min: {hs_flat.min().item():.6f}")
177
+ print(f" Max: {hs_flat.max().item():.6f}")
178
+
179
+ # ========================================================================
180
+ # MODEL STRUCTURE DEEP DIVE
181
+ # ========================================================================
182
+ print("\n" + "=" * 80)
183
+ print("MODEL STRUCTURE DEEP DIVE")
184
+ print("=" * 80)
185
+
186
+ print("\n--- Model Architecture String ---")
187
+ print(model)
188
+
189
+ # ========================================================================
190
+ # SUMMARY DATA FOR REPORT
191
+ # ========================================================================
192
+ summary = {
193
+ "model_name": model_name,
194
+ "total_params": total_params,
195
+ "config": {
196
+ "vocab_size": config.vocab_size,
197
+ "hidden_size": config.hidden_size,
198
+ "intermediate_size": config.intermediate_size,
199
+ "num_hidden_layers": config.num_hidden_layers,
200
+ "num_attention_heads": config.num_attention_heads,
201
+ "num_kv_heads": getattr(config, 'num_key_value_heads', config.num_attention_heads),
202
+ "head_dim": config.hidden_size // config.num_attention_heads,
203
+ "max_position_embeddings": config.max_position_embeddings,
204
+ "rms_norm_eps": getattr(config, 'rms_norm_eps', None),
205
+ "rope_theta": getattr(config, 'rope_theta', None),
206
+ "tie_word_embeddings": getattr(config, 'tie_word_embeddings', None),
207
+ },
208
+ "tokenization": {
209
+ "test_input": test_input,
210
+ "token_ids": input_ids.tolist(),
211
+ "num_tokens": len(input_ids),
212
+ "tokens": [tokenizer.decode([tid]) for tid in input_ids]
213
+ },
214
+ "hidden_states": {
215
+ "num_outputs": len(hidden_states),
216
+ "shape": list(hidden_states[0].shape)
217
+ },
218
+ "param_groups": {k: {"count": len(v), "total": sum(p["numel"] for p in v)} for k, v in param_groups.items()}
219
+ }
220
+
221
+ # Save summary as JSON for report generation
222
+ with open("D:/8bit-threshold-computer/llm_integration/smollm2_analysis.json", "w") as f:
223
+ json.dump(summary, f, indent=2)
224
+
225
+ print("\n" + "=" * 80)
226
+ print("Analysis complete. Summary saved to smollm2_analysis.json")
227
+ print("=" * 80)
228
+
229
+ return summary, model, tokenizer, hidden_states, param_groups
230
+
231
+ if __name__ == "__main__":
232
+ summary, model, tokenizer, hidden_states, param_groups = analyze_smollm2()
llm_integration/model.py CHANGED
@@ -351,76 +351,158 @@ class Extractor(nn.Module):
351
 
352
  class PositionExtractor(nn.Module):
353
  """
354
- Position-specific extraction.
355
- Extracts operand A from first token positions, operand B from later positions.
356
- For "47 + 86": positions 0-2 for A, position 3-4 for op, positions 5-7 for B.
357
  """
358
359
  def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
360
  super().__init__()
 
361
 
362
  self.a_extractor = nn.Sequential(
363
- nn.Linear(hidden_dim * 3, intermediate_dim),
364
  nn.GELU(),
365
- nn.Linear(intermediate_dim, 8),
366
  )
367
 
368
  self.b_extractor = nn.Sequential(
369
- nn.Linear(hidden_dim * 3, intermediate_dim),
370
  nn.GELU(),
371
- nn.Linear(intermediate_dim, 8),
372
  )
373
 
374
- self.op_router = nn.Sequential(
375
- nn.Linear(hidden_dim * 2, intermediate_dim),
376
  nn.GELU(),
377
- nn.Linear(intermediate_dim, len(OPERATIONS)),
378
  )
379
 
380
- def forward(self, hidden: torch.Tensor, mask: torch.Tensor):
381
  """
382
  Args:
383
  hidden: [batch, seq_len, hidden_dim]
384
- mask: [batch, seq_len]
 
385
 
386
  Returns:
387
- a_bits, b_bits, op_logits
388
  """
389
- batch_size, seq_len, hidden_dim = hidden.shape
 
390
 
391
- seq_lens = mask.sum(dim=1).long()
 
392
 
393
  a_features = []
394
  b_features = []
395
  op_features = []
 
396
 
397
  for i in range(batch_size):
398
- slen = seq_lens[i].item()
399
- start = seq_len - slen
400
 
401
- a_pos = hidden[i, start:start+3, :].reshape(-1)
402
- if a_pos.shape[0] < hidden_dim * 3:
403
- a_pos = F.pad(a_pos, (0, hidden_dim * 3 - a_pos.shape[0]))
404
 
405
- op_pos = hidden[i, start+3:start+5, :].reshape(-1)
406
- if op_pos.shape[0] < hidden_dim * 2:
407
- op_pos = F.pad(op_pos, (0, hidden_dim * 2 - op_pos.shape[0]))
408
 
409
- b_pos = hidden[i, start+5:start+8, :].reshape(-1)
410
- if b_pos.shape[0] < hidden_dim * 3:
411
- b_pos = F.pad(b_pos, (0, hidden_dim * 3 - b_pos.shape[0]))
412
 
413
- a_features.append(a_pos)
414
- b_features.append(b_pos)
415
- op_features.append(op_pos)
 
416
 
417
  a_features = torch.stack(a_features)
418
  b_features = torch.stack(b_features)
419
  op_features = torch.stack(op_features)
 
420
 
421
  a_logits = self.a_extractor(a_features)
422
  b_logits = self.b_extractor(b_features)
423
- op_logits = self.op_router(op_features)
424
 
425
  a_soft = torch.sigmoid(a_logits)
426
  b_soft = torch.sigmoid(b_logits)
@@ -429,7 +511,7 @@ class PositionExtractor(nn.Module):
429
  a_bits = a_hard - a_soft.detach() + a_soft
430
  b_bits = b_hard - b_soft.detach() + b_soft
431
 
432
- return a_bits, b_bits, op_logits
433
 
434
 
435
  class DigitExtractor(nn.Module):
@@ -589,8 +671,15 @@ class ArithmeticModel(nn.Module):
589
  print(f" Extractor params: {trainable_ext:,}", flush=True)
590
  print(f" Total trainable: {total_trainable:,}", flush=True)
591
 
592
- def get_hidden_states(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
593
- """Get hidden states from specified layer."""
 
594
  inputs = self.tokenizer(
595
  texts,
596
  return_tensors='pt',
@@ -607,8 +696,9 @@ class ArithmeticModel(nn.Module):
607
 
608
  hidden = outputs.hidden_states[self.extract_layer].float()
609
  mask = inputs.attention_mask.float()
 
610
 
611
- return hidden, mask
612
 
613
  def forward(self, texts: list[str]):
614
  """
@@ -617,16 +707,25 @@ class ArithmeticModel(nn.Module):
617
  Returns:
618
  result_bits, a_bits, b_bits, op_logits
619
  If digit_pred: also returns a_digit_logits, b_digit_logits
 
620
  """
621
- hidden, mask = self.get_hidden_states(texts)
622
 
623
- extractor_out = self.extractor(hidden, mask)
624
 
625
  if self.digit_pred:
626
  a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out
627
  else:
628
  a_bits, b_bits, op_logits = extractor_out
629
  a_digit_logits, b_digit_logits = None, None
 
630
 
631
  op_probs = torch.softmax(op_logits, dim=-1)
632
 
@@ -634,6 +733,8 @@ class ArithmeticModel(nn.Module):
634
 
635
  if self.digit_pred:
636
  return result_bits, a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
637
  return result_bits, a_bits, b_bits, op_logits
638
 
639
  def trainable_parameters(self):
 
351
 
352
  class PositionExtractor(nn.Module):
353
  """
354
+ Position-specific extraction with dynamic operator detection.
355
+
356
+ Tokenization pattern for "A op B":
357
+ [A_digits...] [operator] [space] [B_digits...]
358
+
359
+ Examples:
360
+ "5 + 3" -> ['5', ' +', ' ', '3'] (positions: A=0, op=1, B=3)
361
+ "47 + 86" -> ['4', '7', ' +', ' ', '8', '6'] (positions: A=0-1, op=2, B=4-5)
362
+ "127 + 128" -> ['1','2','7',' +', ' ','1','2','8'] (positions: A=0-2, op=3, B=5-7)
363
+
364
+ Token IDs (SmolLM2):
365
+ Digits '0'-'9': 32-41
366
+ Operators: ' +'=1232, ' -'=731, ' *'=1672, ' >'=2986, ' <'=2067, ' =='=1758
367
+ Space: 216
368
  """
369
 
370
+ DIGIT_TOKENS = set(range(32, 42))
371
+ OPERATOR_TOKENS = {
372
+ 1232: 0, # ' +' -> add
373
+ 731: 1, # ' -' -> sub
374
+ 1672: 2, # ' *' -> mul
375
+ 2986: 3, # ' >' -> gt
376
+ 2067: 4, # ' <' -> lt
377
+ 1758: 5, # ' ==' -> eq
378
+ }
379
+ SPACE_TOKEN = 216
380
+ MAX_DIGITS = 3
381
+
382
  def __init__(self, hidden_dim: int = 960, intermediate_dim: int = 256):
383
  super().__init__()
384
+ self.hidden_dim = hidden_dim
385
 
386
  self.a_extractor = nn.Sequential(
387
+ nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim),
388
  nn.GELU(),
389
+ nn.Linear(intermediate_dim, intermediate_dim // 2),
390
+ nn.GELU(),
391
+ nn.Linear(intermediate_dim // 2, 8),
392
  )
393
 
394
  self.b_extractor = nn.Sequential(
395
+ nn.Linear(hidden_dim * self.MAX_DIGITS, intermediate_dim),
396
  nn.GELU(),
397
+ nn.Linear(intermediate_dim, intermediate_dim // 2),
398
+ nn.GELU(),
399
+ nn.Linear(intermediate_dim // 2, 8),
400
  )
401
 
402
+ self.op_extractor = nn.Sequential(
403
+ nn.Linear(hidden_dim, intermediate_dim // 2),
404
  nn.GELU(),
405
+ nn.Linear(intermediate_dim // 2, len(OPERATIONS)),
406
  )
407
 
408
+ def _find_operator_position(self, token_ids: torch.Tensor) -> tuple[int, int]:
409
+ """
410
+ Find operator token position and its operation index.
411
+
412
+ Args:
413
+ token_ids: [seq_len] tensor of token IDs
414
+
415
+ Returns:
416
+ (position, op_index) or (-1, -1) if not found
417
+ """
418
+ for pos, tid in enumerate(token_ids.tolist()):
419
+ if tid in self.OPERATOR_TOKENS:
420
+ return pos, self.OPERATOR_TOKENS[tid]
421
+ return -1, -1
422
+
423
+ def _extract_digit_features(self, hidden: torch.Tensor, start: int, end: int) -> torch.Tensor:
424
+ """
425
+ Extract and pad digit hidden states to fixed size.
426
+
427
+ Args:
428
+ hidden: [seq_len, hidden_dim]
429
+ start: start position (inclusive)
430
+ end: end position (exclusive)
431
+
432
+ Returns:
433
+ [hidden_dim * MAX_DIGITS] flattened features, zero-padded on the LEFT
434
+ (so units digit is always at the same position regardless of number length)
435
+ """
436
+ n_digits = end - start
437
+ features = torch.zeros(self.MAX_DIGITS * self.hidden_dim, device=hidden.device)
438
+
439
+ if n_digits > 0 and n_digits <= self.MAX_DIGITS:
440
+ digit_hidden = hidden[start:end, :].reshape(-1)
441
+ pad_size = (self.MAX_DIGITS - n_digits) * self.hidden_dim
442
+ features[pad_size:] = digit_hidden
443
+
444
+ return features
445
+
446
+ def forward(self, hidden: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor = None):
447
  """
448
  Args:
449
  hidden: [batch, seq_len, hidden_dim]
450
+ mask: [batch, seq_len] attention mask
451
+ token_ids: [batch, seq_len] token IDs (required for operator detection)
452
 
453
  Returns:
454
+ a_bits: [batch, 8]
455
+ b_bits: [batch, 8]
456
+ op_logits: [batch, 6]
457
  """
458
+ if token_ids is None:
459
+ raise ValueError("PositionExtractor requires token_ids for operator detection")
460
 
461
+ batch_size, seq_len, hidden_dim = hidden.shape
462
+ device = hidden.device
463
 
464
  a_features = []
465
  b_features = []
466
  op_features = []
467
+ op_indices = []
468
 
469
  for i in range(batch_size):
470
+ seq_mask = mask[i].bool()
471
+ valid_len = seq_mask.sum().item()
472
+ start_pos = seq_len - valid_len
473
+
474
+ valid_tokens = token_ids[i, start_pos:]
475
+ valid_hidden = hidden[i, start_pos:, :]
476
+
477
+ op_pos, op_idx = self._find_operator_position(valid_tokens)
478
 
479
+ if op_pos == -1:
480
+ a_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device)
481
+ b_feat = torch.zeros(self.MAX_DIGITS * hidden_dim, device=device)
482
+ op_feat = torch.zeros(hidden_dim, device=device)
483
+ op_idx = 0
484
+ else:
485
+ a_feat = self._extract_digit_features(valid_hidden, 0, op_pos)
486
 
487
+ op_feat = valid_hidden[op_pos, :]
488
 
489
+ b_start = op_pos + 2 if (op_pos + 1 < valid_len and
490
+ valid_tokens[op_pos + 1].item() == self.SPACE_TOKEN) else op_pos + 1
491
+ b_feat = self._extract_digit_features(valid_hidden, b_start, valid_len)
492
 
493
+ a_features.append(a_feat)
494
+ b_features.append(b_feat)
495
+ op_features.append(op_feat)
496
+ op_indices.append(op_idx)
497
 
498
  a_features = torch.stack(a_features)
499
  b_features = torch.stack(b_features)
500
  op_features = torch.stack(op_features)
501
+ op_indices_tensor = torch.tensor(op_indices, device=device, dtype=torch.long)
502
 
503
  a_logits = self.a_extractor(a_features)
504
  b_logits = self.b_extractor(b_features)
505
+ op_logits = self.op_extractor(op_features)
506
 
507
  a_soft = torch.sigmoid(a_logits)
508
  b_soft = torch.sigmoid(b_logits)
 
511
  a_bits = a_hard - a_soft.detach() + a_soft
512
  b_bits = b_hard - b_soft.detach() + b_soft
513
 
514
+ return a_bits, b_bits, op_logits, op_indices_tensor
515
 
516
 
517
  class DigitExtractor(nn.Module):
 
671
  print(f" Extractor params: {trainable_ext:,}", flush=True)
672
  print(f" Total trainable: {total_trainable:,}", flush=True)
673
 
674
+ def get_hidden_states(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
675
+ """
676
+ Get hidden states from specified layer.
677
+
678
+ Returns:
679
+ hidden: [batch, seq_len, hidden_dim] hidden states
680
+ mask: [batch, seq_len] attention mask
681
+ token_ids: [batch, seq_len] input token IDs
682
+ """
683
  inputs = self.tokenizer(
684
  texts,
685
  return_tensors='pt',
 
696
 
697
  hidden = outputs.hidden_states[self.extract_layer].float()
698
  mask = inputs.attention_mask.float()
699
+ token_ids = inputs.input_ids
700
 
701
+ return hidden, mask, token_ids
702
 
703
  def forward(self, texts: list[str]):
704
  """
 
707
  Returns:
708
  result_bits, a_bits, b_bits, op_logits
709
  If digit_pred: also returns a_digit_logits, b_digit_logits
710
+ If position_extract: also returns op_indices (ground truth from tokenization)
711
  """
712
+ hidden, mask, token_ids = self.get_hidden_states(texts)
713
 
714
+ if self.position_extract:
715
+ extractor_out = self.extractor(hidden, mask, token_ids)
716
+ else:
717
+ extractor_out = self.extractor(hidden, mask)
718
 
719
  if self.digit_pred:
720
  a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits = extractor_out
721
+ op_indices_from_tokens = None
722
+ elif self.position_extract:
723
+ a_bits, b_bits, op_logits, op_indices_from_tokens = extractor_out
724
+ a_digit_logits, b_digit_logits = None, None
725
  else:
726
  a_bits, b_bits, op_logits = extractor_out
727
  a_digit_logits, b_digit_logits = None, None
728
+ op_indices_from_tokens = None
729
 
730
  op_probs = torch.softmax(op_logits, dim=-1)
731
 
 
733
 
734
  if self.digit_pred:
735
  return result_bits, a_bits, b_bits, op_logits, a_digit_logits, b_digit_logits
736
+ if self.position_extract:
737
+ return result_bits, a_bits, b_bits, op_logits, op_indices_from_tokens
738
  return result_bits, a_bits, b_bits, op_logits
739
 
740
  def trainable_parameters(self):
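The LEFT-padding contract of `_extract_digit_features` can be illustrated without tensors; this standalone sketch mirrors the logic in the diff above, with labels standing in for per-digit hidden states:

```python
# With MAX_DIGITS = 3, the units digit always lands in the last slot,
# regardless of whether the operand has 1, 2, or 3 digits.
MAX_DIGITS = 3

def left_pad_slots(digit_states):
    """digit_states: list of per-digit feature vectors (here, plain labels)."""
    pad = [None] * (MAX_DIGITS - len(digit_states))
    return pad + digit_states

print(left_pad_slots(["5"]))            # [None, None, '5']
print(left_pad_slots(["4", "7"]))       # [None, '4', '7']
print(left_pad_slots(["1", "2", "7"]))  # ['1', '2', '7']
```

This alignment is what lets the `a_extractor`/`b_extractor` MLPs learn position-stable digit features across 1-3 digit operands.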
llm_integration/smollm2_analysis.json ADDED
@@ -0,0 +1,439 @@
1
+ {
2
+ "model_name": "HuggingFaceTB/SmolLM2-360M-Instruct",
3
+ "total_params": 361821120,
4
+ "config": {
5
+ "vocab_size": 49152,
6
+ "hidden_size": 960,
7
+ "intermediate_size": 2560,
8
+ "num_hidden_layers": 32,
9
+ "num_attention_heads": 15,
10
+ "num_kv_heads": 5,
11
+ "head_dim": 64,
12
+ "max_position_embeddings": 8192,
13
+ "rms_norm_eps": 1e-05,
14
+ "rope_theta": 100000,
15
+ "tie_word_embeddings": true
16
+ },
17
+ "tokenization": {
18
+ "test_input": "47 + 86",
19
+ "token_ids": [
20
+ 36,
21
+ 39,
22
+ 1232,
23
+ 216,
24
+ 40,
25
+ 38
26
+ ],
27
+ "num_tokens": 6,
28
+ "tokens": [
29
+ "4",
30
+ "7",
31
+ " +",
32
+ " ",
33
+ "8",
34
+ "6"
35
+ ]
36
+ },
37
+ "hidden_states": {
38
+ "num_outputs": 33,
39
+ "shape": [
40
+ 1,
41
+ 6,
42
+ 960
43
+ ]
44
+ },
45
+ "param_groups": {
46
+ "Embedding": {
47
+ "count": 1,
48
+ "total": 47185920
49
+ },
50
+ "Layer 0 - Attention": {
51
+ "count": 4,
52
+ "total": 2457600
53
+ },
54
+ "Layer 0 - MLP": {
55
+ "count": 3,
56
+ "total": 7372800
57
+ },
58
+ "Layer 0 - Norms": {
59
+ "count": 2,
60
+ "total": 1920
61
+ },
62
+ "Layer 1 - Attention": {
63
+ "count": 4,
64
+ "total": 2457600
65
+ },
66
+ "Layer 1 - MLP": {
67
+ "count": 3,
68
+ "total": 7372800
69
+ },
70
+ "Layer 1 - Norms": {
71
+ "count": 2,
72
+ "total": 1920
73
+ },
74
+ "Layer 2 - Attention": {
75
+ "count": 4,
76
+ "total": 2457600
77
+ },
78
+ "Layer 2 - MLP": {
79
+ "count": 3,
80
+ "total": 7372800
81
+ },
82
+ "Layer 2 - Norms": {
83
+ "count": 2,
84
+ "total": 1920
85
+ },
86
+ "Layer 3 - Attention": {
87
+ "count": 4,
88
+ "total": 2457600
89
+ },
90
+ "Layer 3 - MLP": {
91
+ "count": 3,
92
+ "total": 7372800
93
+ },
94
+ "Layer 3 - Norms": {
95
+ "count": 2,
96
+ "total": 1920
97
+ },
98
+ "Layer 4 - Attention": {
99
+ "count": 4,
100
+ "total": 2457600
101
+ },
102
+ "Layer 4 - MLP": {
103
+ "count": 3,
104
+ "total": 7372800
105
+ },
106
+ "Layer 4 - Norms": {
107
+ "count": 2,
108
+ "total": 1920
109
+ },
110
+ "Layer 5 - Attention": {
111
+ "count": 4,
112
+ "total": 2457600
113
+ },
114
+ "Layer 5 - MLP": {
115
+ "count": 3,
116
+ "total": 7372800
117
+ },
118
+ "Layer 5 - Norms": {
119
+ "count": 2,
120
+ "total": 1920
121
+ },
122
+ "Layer 6 - Attention": {
123
+ "count": 4,
124
+ "total": 2457600
125
+ },
126
+ "Layer 6 - MLP": {
127
+ "count": 3,
128
+ "total": 7372800
129
+ },
130
+ "Layer 6 - Norms": {
131
+ "count": 2,
132
+ "total": 1920
133
+ },
134
+ "Layer 7 - Attention": {
135
+ "count": 4,
136
+ "total": 2457600
137
+ },
138
+ "Layer 7 - MLP": {
139
+ "count": 3,
140
+ "total": 7372800
141
+ },
142
+ "Layer 7 - Norms": {
143
+ "count": 2,
144
+ "total": 1920
145
+ },
146
+ "Layer 8 - Attention": {
147
+ "count": 4,
148
+ "total": 2457600
149
+ },
150
+ "Layer 8 - MLP": {
151
+ "count": 3,
152
+ "total": 7372800
153
+ },
154
+ "Layer 8 - Norms": {
155
+ "count": 2,
156
+ "total": 1920
157
+ },
158
+ "Layer 9 - Attention": {
159
+ "count": 4,
160
+ "total": 2457600
161
+ },
162
+ "Layer 9 - MLP": {
163
+ "count": 3,
164
+ "total": 7372800
165
+ },
166
+ "Layer 9 - Norms": {
167
+ "count": 2,
168
+ "total": 1920
169
+ },
170
+ "Layer 10 - Attention": {
171
+ "count": 4,
172
+ "total": 2457600
173
+ },
174
+ "Layer 10 - MLP": {
175
+ "count": 3,
176
+ "total": 7372800
177
+ },
178
+ "Layer 10 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 11 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 11 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 11 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 12 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 12 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 12 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 13 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 13 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 13 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 14 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 14 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 14 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 15 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 15 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 15 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 16 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 16 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 16 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 17 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 17 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 17 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 18 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 18 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 18 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 19 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 19 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 19 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 20 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 20 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 20 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 21 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 21 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 21 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 22 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 22 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 22 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 23 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 23 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 23 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 24 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 24 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 24 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 25 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 25 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 25 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 26 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 26 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 26 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 27 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 27 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 27 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 28 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 28 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 28 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 29 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 29 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 29 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 30 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 30 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 30 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Layer 31 - Attention": {
+ "count": 4,
+ "total": 2457600
+ },
+ "Layer 31 - MLP": {
+ "count": 3,
+ "total": 7372800
+ },
+ "Layer 31 - Norms": {
+ "count": 2,
+ "total": 1920
+ },
+ "Final Norm": {
+ "count": 1,
+ "total": 960
+ }
+ }
+ }
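The per-layer totals in the inventory above can be cross-checked against the architecture dimensions stated in this commit (hidden_dim=960, 15 query heads and 5 KV heads with head_dim=64, SwiGLU intermediate size 2560). A minimal sketch of that arithmetic follows; the variable names are illustrative, not the actual state_dict keys:

```python
# Reproduce the per-layer weight counts from SmolLM2-360M dimensions.
hidden = 960
n_q_heads, n_kv_heads, head_dim = 15, 5, 64
intermediate = 2560

# Attention: 4 weight tensors (q_proj, k_proj, v_proj, o_proj).
# K/V projections are smaller under GQA (5 KV heads vs. 15 query heads).
q = hidden * n_q_heads * head_dim    # 960 x 960  = 921,600
k = hidden * n_kv_heads * head_dim   # 960 x 320  = 307,200
v = hidden * n_kv_heads * head_dim   # 960 x 320  = 307,200
o = n_q_heads * head_dim * hidden    # 960 x 960  = 921,600
attention = q + k + v + o            # 2,457,600

# SwiGLU MLP: 3 weight tensors (gate_proj, up_proj, down_proj).
mlp = 2 * (hidden * intermediate) + (intermediate * hidden)  # 7,372,800

# Two RMSNorms per layer (pre-attention and pre-MLP), one scale vector each.
norms = 2 * hidden                   # 1,920

print(attention, mlp, norms)  # 2457600 7372800 1920
```

This confirms the JSON's per-layer figures: 2,457,600 attention parameters, 7,372,800 MLP parameters, and 1,920 norm parameters per transformer layer.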
llm_integration/train.py CHANGED
@@ -398,7 +398,9 @@ def evaluate_llm(model, n_samples: int = 500):
         text, a, b, op, expected = generate_problem()
 
         with torch.no_grad():
-            result_bits, a_bits, b_bits, op_logits = model([text])
+            outputs = model([text])
+            result_bits = outputs[0]
+            op_logits = outputs[3]
 
         pred_result = bits_to_int(result_bits[0])
         pred_op = OPERATIONS[op_logits[0].argmax().item()]
@@ -502,7 +504,8 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
 
         optimizer.zero_grad()
 
-        pred_bits, a_bits, b_bits, op_logits = model(batch_texts)
+        outputs = model(batch_texts)
+        pred_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
 
         loss, losses = compute_llm_loss(
             pred_bits, a_bits, b_bits, op_logits,
@@ -556,7 +559,8 @@ def train_llm(epochs: int = 100, batch_size: int = 256, lr: float = 3e-4,
         for _ in range(10):
             text, a, b, op, expected = generate_problem()
             with torch.no_grad():
-                result_bits, a_bits, b_bits, op_logits = model([text])
+                outputs = model([text])
+                result_bits, a_bits, b_bits, op_logits = outputs[0], outputs[1], outputs[2], outputs[3]
             pred = bits_to_int(result_bits[0])
             pred_a = bits_to_int(a_bits[0])
             pred_b = bits_to_int(b_bits[0])
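The callers above decode predictions with `bits_to_int`, whose implementation is not shown in this diff. A hedged sketch of what such a helper might look like, assuming the model emits one logit per bit (sigmoid above 0.5 means the bit is set) ordered least-significant bit first:

```python
# Assumed sketch of bits_to_int: the real train.py helper may differ in
# bit ordering or thresholding.
import torch

def bits_to_int(bit_logits: torch.Tensor) -> int:
    """Decode a vector of per-bit logits into an integer, LSB first."""
    bits = (torch.sigmoid(bit_logits) > 0.5).long()
    return int(sum(int(b) << i for i, b in enumerate(bits)))

# Positive logits at positions 0 and 2 -> binary 101 -> 5.
print(bits_to_int(torch.tensor([5.0, -5.0, 5.0])))  # 5
```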