RobbiePasquale committed on
Commit 8793340 · verified · 1 Parent(s): 155f547

Update README.md

Files changed (1)
  1. README.md +51 -130
README.md CHANGED
@@ -315,104 +315,6 @@ After each epoch, the model is evaluated on the validation set, computing the av
 ### Checkpoints
 At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
 
-
-## Language Model Architecture
-
-### Transformer Architecture
-
-The Transformer architecture is foundational to the LightBulb model, facilitating efficient sequence processing through self-attention mechanisms and feedforward networks enhanced by Mixture of Experts (MoE).
-
-#### TransformerBlock
-
-Each `TransformerBlock` consists of the following components:
-
-1. **Self-Attention (`self_attention`)**
-2. **Layer Normalization (`norm1`)**
-3. **Cross-Attention (`cross_attention`)**
-4. **Layer Normalization (`norm2`)**
-5. **Mixture of Experts (`moe`)**
-6. **Layer Normalization (`norm3`)**
-
-**Mathematical Operations:**
-
-1. **Self-Attention:**
-\[
-\text{Attn}_{\text{self}} = \text{SelfAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
-\]
-
-2. **Residual Connection and Layer Norm:**
-\[
-x = \text{LayerNorm}(x + \text{Attn}_{\text{self}})
-\]
-
-3. **Cross-Attention (if applicable):**
-\[
-\text{Attn}_{\text{cross}} = \text{CrossAttention}(Q, K_{\text{enc}}, V_{\text{enc}}) = \text{softmax}\left(\frac{QK_{\text{enc}}^\top}{\sqrt{d_k}}\right)V_{\text{enc}}
-\]
-\[
-x = \text{LayerNorm}(x + \text{Attn}_{\text{cross}})
-\]
-
-4. **Mixture of Experts:**
-\[
-\text{MoE}_{\text{output}} = \sum_{i=1}^k g_i(x) \cdot \text{Expert}_i(x)
-\]
-
-5. **Residual Connection and Layer Norm:**
-\[
-x = \text{LayerNorm}(x + \text{MoE}_{\text{output}})
-\]
-
-**Key Parameters:**
-
-- \( d_{\text{model}} \): Dimensionality of the model embeddings.
-- \( d_k \): Dimensionality of the key vectors in attention.
-- \( \text{num\_heads} \): Number of attention heads.
-- \( \text{num\_experts} \): Number of experts in the MoE layer.
-- \( \text{top\_k} \): Number of top experts to activate in MoE.
-- \( \text{dropout} \): Dropout rate for regularization.
-
-#### Transformer
-
-The `Transformer` class orchestrates multiple `TransformerBlock` instances within encoder and decoder stacks.
-
-**Components:**
-
-1. **Embedding Layer:**
-\[
-E = \text{Embedding}(input\_ids) \times \sqrt{d_{\text{model}}}
-\]
-
-2. **Rotary Positional Encoding (`rotary_positional_encoding`):**
-   - Injects positional information by rotating the embeddings based on token positions.
-
-3. **Encoder and Decoder Layers:**
-   - Multiple `TransformerBlock` instances processing the embedded inputs.
-
-4. **Output Layer:**
-\[
-\text{Output} = \text{Linear}(d_{\text{model}}, \text{output\_dim})(\text{Decoder Output})
-\]
-
-5. **Beam Search with Multi-Token Prediction (`generate_with_beam_search`):**
-   - Generates sequences by predicting multiple tokens at each step, maintaining a beam of top candidates.
-
-**Forward Pass:**
-
-\[
-\begin{align*}
-\text{Encoder:} & \quad X_{\text{enc}} = \text{Embedding}(src) \times \sqrt{d_{\text{model}}} \\
-& \quad X_{\text{enc}} = \text{RotaryPositionalEncoding}(X_{\text{enc}}) \\
-& \quad X_{\text{enc}} = \text{EncoderLayers}(X_{\text{enc}}) \\
-\\
-\text{Decoder:} & \quad X_{\text{dec}} = \text{Embedding}(tgt) \times \sqrt{d_{\text{model}}} \\
-& \quad X_{\text{dec}} = \text{RotaryPositionalEncoding}(X_{\text{dec}}) \\
-& \quad X_{\text{dec}} = \text{DecoderLayers}(X_{\text{dec}}, X_{\text{enc}}) \\
-\\
-\text{Output:} & \quad \text{output} = \text{Linear}(X_{\text{dec}})
-\end{align*}
-\]
-
 ---
 
 ### World Model Components
@@ -425,10 +327,11 @@ The World Model encapsulates components that model state representations, dynami
 Transforms the transformer's output embeddings into a compact state representation suitable for modeling and prediction tasks.
 
 **Mathematical Operation:**
 \[
 \text{State} = \text{LayerNorm}\left(\text{Linear}(d_{\text{model}} \rightarrow d_{\text{state}})\left(\text{Linear}(vocab\_dim \rightarrow d_{\text{model}})(\text{Transformer Output})\right)\right)
 \]
-
 **Explanation:**
 Sequential linear transformations project high-dimensional embeddings into a lower-dimensional state space, followed by layer normalization for stability.
 
@@ -438,9 +341,11 @@ Sequential linear transformations project high-dimensional embeddings into a low
 Models how the state evolves in response to actions (thoughts) taken by the model.
 
 **Mathematical Operation:**
 \[
 \text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
 \]
 
 **Explanation:**
 Predicts the subsequent state by combining the current state representation with an encoded action, effectively simulating the consequences of actions within the Tree of Thought.
@@ -451,10 +356,11 @@ Predicts the subsequent state by combining the current state representation with
 Predicts policy logits (action probabilities) and value estimates (state evaluations) based on the current state.
 
 **Mathematical Operation:**
 \[
 (\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
 \]
-
 **Explanation:**
 - **Policy Logits:** Used to derive action probabilities via softmax.
 - **Value Estimate:** Represents the expected reward or quality of the current state.
@@ -465,10 +371,11 @@ Predicts policy logits (action probabilities) and value estimates (state evaluat
 Encodes discrete actions (thoughts) into continuous embeddings compatible with the DynamicsNetwork.
 
 **Mathematical Operation:**
 \[
 \text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
 \]
-
 **Explanation:**
 Converts action indices into dense vector representations, facilitating their integration into state transition modeling.
 
@@ -492,10 +399,11 @@ Represents a node in the Tree of Thought, corresponding to a specific action or
 **Mathematical Representation:**
 
 Each `ThoughtNode` can be represented as a tree node in a directed graph:
 \[
 \text{ThoughtNode} = (\text{name}, \{\text{children}\})
 \]
-
 #### State
 
 **Function:**
@@ -509,39 +417,47 @@ Represents the current state within the MCTS and Tree of Thought framework.
 - `thought_node`: Reference to the current `ThoughtNode` in the Tree of Thought.
 
 **Action Application (`apply_action`):**
-
 \[
 \text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
 \]
 \[
 \text{New Representation} = \text{Concat}(\text{Current Representation}, \text{Next State} \rightarrow \text{unsqueeze}(1))
 \]
-
 **Procedure:**
 
 1. **Action Encoding:**
 \[
 \text{Action Index} = \text{Index of Action}
 \]
 \[
 \text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
 \]
-
 2. **State Extraction:**
 \[
 \text{Current State} = \text{representation}[:, -1, :]
 \]
-
 3. **State Transition:**
 \[
 \text{Next State Representation} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
 \]
-
 4. **Representation Update:**
 \[
 \text{New Representation} = \text{Concat}(\text{representation}, \text{Next State Representation} \times \text{unsqueeze}(1))
 \]
-
 5. **Thought Node Update:**
   - Navigate to the child `ThoughtNode` corresponding to the applied action.
 
@@ -571,10 +487,11 @@ Represents a node in the MCTS search tree, encapsulating a specific state in the
 **Mathematical Representation:**
 
 Each `MCTSNode` can be considered as:
 \[
 \text{MCTSNode} = (\text{state}, \text{parent}, \text{action}, \{\text{children}\}, \text{visit\_count}, \text{value\_sum}, \text{prior}, \text{entropy}, \text{variance})
 \]
-
 #### MCTS Algorithm
 
 The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to the LightBulb model's architecture.
@@ -604,9 +521,11 @@ The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to th
   - Add the candidate sequence to `all_candidates`.
 - **Beam Pruning:**
   - Sort all candidates based on a combined score:
 \[
 \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
 \]
   - Retain the top `beam_size` candidates for the next iteration.
 4. **Result Extraction:**
   - After completing iterations, select the best action sequence from the final beam.
@@ -620,6 +539,7 @@ The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to th
 - Calculate entropy and variance of the policy distribution.
 - Expand the node by creating child nodes based on the Tree of Thought and assign priors from policy probabilities.
 - **Mathematical Operations:**
 \[
 (\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
 \]
@@ -632,19 +552,22 @@ The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to th
 \[
 \text{Variance} = \text{Var}(P)
 \]
-
 4. **Backpropagation (`backpropagate`):**
   - **Function:** Updates the `visit_count` and `value_sum` for nodes along the path from the evaluated node back to the root.
   - **Procedure:**
 \[
 \text{For each node in the path:} \\
 \quad \text{node.visit\_count} \mathrel{+}= 1 \\
 \quad \text{node.value\_sum} \mathrel{+}= \text{Value Estimate}
 \]
-
 5. **Upper Confidence Bound (UCB) Score (`ucb_score`):**
   - **Function:** Balances exploration of less-visited nodes and exploitation of high-value nodes.
   - **Mathematical Operation:**
 \[
 \text{UCB Score} = \text{Average Value} + \text{Exploration Term} + \text{Entropy Term} + \text{Variance Term}
 \]
@@ -661,7 +584,7 @@ The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to th
 \[
 \text{Variance Term} = 0.05 \times \text{variance}
 \]
-
 6. **Best Action Sequence Extraction (`best_action_sequence`):**
   - **Function:** Extracts the most promising action sequence from the MCTS tree after all iterations.
   - **Procedure:**
@@ -671,21 +594,6 @@ The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to th
 
 ---
 
-
-### Mixture of Experts (MoE)
-
-\[
-\text{MoE}(x) = \sum_{i=1}^k g_i(x) \cdot \text{Expert}_i(x)
-\]
-Where:
-- \( g_i(x) \): Gating weights ensuring sparsity (only top-k experts are active).
-- \( \text{Expert}_i(x) \): Outputs from the expert networks.
-- \( k \): Number of top experts to activate.
-
-**Explanation:**
-- For each input, only the top-k experts (based on gating scores) process the data.
-- Reduces computational load while maintaining high capacity.
-
 ### Beam Search with Multi-Token Prediction
 
 **Purpose:** Efficiently explores multiple possible token sequences to generate coherent and diverse outputs by predicting multiple tokens at each step.
@@ -693,15 +601,26 @@ Where:
 **Procedure:**
 
 1. **Beam Initialization:**
 - Start with a beam containing the start-of-sequence (BOS) token.
 \[
 \text{beam} = \left\{ \left( \text{seq} = [\text{BOS}], \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0 \right) \right\}
 \]
-
 2. **Iterative Expansion:**
-   - For each iteration up to \( \frac{\text{max\_length}}{n\_tokens\_predict} \):
   - For each sequence in the beam:
-     - Predict the next \( n\_tokens\_predict \) tokens.
     - Calculate their probabilities.
     - Select top-k token sequences based on cumulative scores.
 
@@ -712,6 +631,7 @@ Where:
 - Continue until the maximum length is reached or all sequences end with the end-of-sequence (EOS) token.
 
 **Mathematical Operations:**
 \[
 \text{Score} = \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
 \]
@@ -721,6 +641,7 @@ Where:
 \[
 \text{Variance} = \text{Var}(P)
 \]
 Where \( P \) is the probability distribution over the vocabulary.
 
 ### Upper Confidence Bound (UCB) in MCTS
 
315
  ### Checkpoints
316
  At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
317
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  ---
319
 
320
  ### World Model Components
 
 Transforms the transformer's output embeddings into a compact state representation suitable for modeling and prediction tasks.
 
 **Mathematical Operation:**
+```
 \[
 \text{State} = \text{LayerNorm}\left(\text{Linear}(d_{\text{model}} \rightarrow d_{\text{state}})\left(\text{Linear}(vocab\_dim \rightarrow d_{\text{model}})(\text{Transformer Output})\right)\right)
 \]
+```
 **Explanation:**
 Sequential linear transformations project high-dimensional embeddings into a lower-dimensional state space, followed by layer normalization for stability.
 
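For illustration, the operation above can be sketched in PyTorch. The class name, layer names, and dimensions here are assumptions for the sketch, not taken from the repository:

```python
import torch
import torch.nn as nn

class RepresentationNetwork(nn.Module):
    """Sketch of the formula above: Linear(vocab_dim -> d_model),
    then Linear(d_model -> d_state), then LayerNorm."""
    def __init__(self, vocab_dim: int, d_model: int, d_state: int):
        super().__init__()
        self.to_model = nn.Linear(vocab_dim, d_model)
        self.to_state = nn.Linear(d_model, d_state)
        self.norm = nn.LayerNorm(d_state)

    def forward(self, transformer_output: torch.Tensor) -> torch.Tensor:
        return self.norm(self.to_state(self.to_model(transformer_output)))

net = RepresentationNetwork(vocab_dim=100, d_model=32, d_state=16)
# (batch, seq, vocab_dim) -> (batch, seq, d_state)
state = net(torch.randn(2, 10, 100))
```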
 Models how the state evolves in response to actions (thoughts) taken by the model.
 
 **Mathematical Operation:**
+```
 \[
 \text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
 \]
+```
 
 **Explanation:**
 Predicts the subsequent state by combining the current state representation with an encoded action, effectively simulating the consequences of actions within the Tree of Thought.
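A minimal PyTorch sketch of this transition, assuming the network consumes the concatenated state and action embedding (the hidden size and layer layout are assumptions):

```python
import torch
import torch.nn as nn

class DynamicsNetwork(nn.Module):
    """Sketch: predict the next state from the current state plus
    an action embedding, as in the formula above."""
    def __init__(self, d_state: int, d_action: int, d_hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_state + d_action, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_state),
            nn.LayerNorm(d_state),
        )

    def forward(self, state: torch.Tensor, action_embedding: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([state, action_embedding], dim=-1))

dynamics = DynamicsNetwork(d_state=16, d_action=8)
next_state = dynamics(torch.randn(4, 16), torch.randn(4, 8))
```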
 Predicts policy logits (action probabilities) and value estimates (state evaluations) based on the current state.
 
 **Mathematical Operation:**
+```
 \[
 (\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
 \]
+```
 **Explanation:**
 - **Policy Logits:** Used to derive action probabilities via softmax.
 - **Value Estimate:** Represents the expected reward or quality of the current state.
 
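A hedged sketch of the two-headed network implied above; trunk width and head shapes are assumptions:

```python
import torch
import torch.nn as nn

class PredictionNetwork(nn.Module):
    """Sketch: a shared trunk feeding a policy head (logits over
    actions) and a value head (scalar state evaluation)."""
    def __init__(self, d_state: int, num_actions: int, d_hidden: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(d_state, d_hidden), nn.ReLU())
        self.policy_head = nn.Linear(d_hidden, num_actions)
        self.value_head = nn.Linear(d_hidden, 1)

    def forward(self, state: torch.Tensor):
        h = self.trunk(state)
        return self.policy_head(h), self.value_head(h)

net = PredictionNetwork(d_state=16, num_actions=5)
policy_logits, value = net(torch.randn(4, 16))
# Action probabilities via softmax, as noted in the explanation
probs = torch.softmax(policy_logits, dim=-1)
```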
 Encodes discrete actions (thoughts) into continuous embeddings compatible with the DynamicsNetwork.
 
 **Mathematical Operation:**
+```
 \[
 \text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
 \]
+```
 **Explanation:**
 Converts action indices into dense vector representations, facilitating their integration into state transition modeling.
 
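This is essentially an embedding lookup; a minimal sketch (sizes and names assumed):

```python
import torch
import torch.nn as nn

class ActionEncoder(nn.Module):
    """Sketch: an embedding table mapping discrete action indices
    to dense vectors the DynamicsNetwork can consume."""
    def __init__(self, num_actions: int, d_action: int):
        super().__init__()
        self.embedding = nn.Embedding(num_actions, d_action)

    def forward(self, action_index: torch.Tensor) -> torch.Tensor:
        return self.embedding(action_index)

encoder = ActionEncoder(num_actions=10, d_action=8)
action_embedding = encoder(torch.tensor([3]))
```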
 **Mathematical Representation:**
 
 Each `ThoughtNode` can be represented as a tree node in a directed graph:
+```
 \[
 \text{ThoughtNode} = (\text{name}, \{\text{children}\})
 \]
+```
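The (name, {children}) structure above can be sketched as a small Python class; the `add_child` helper is illustrative:

```python
class ThoughtNode:
    """Minimal sketch of the (name, {children}) tuple above."""
    def __init__(self, name: str, children=None):
        self.name = name
        self.children = list(children) if children else []

    def add_child(self, child: "ThoughtNode") -> "ThoughtNode":
        self.children.append(child)
        return child

# Build a tiny tree: root -> plan -> refine
root = ThoughtNode("root")
plan = root.add_child(ThoughtNode("plan"))
plan.add_child(ThoughtNode("refine"))
```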
 #### State
 
 **Function:**
 
 - `thought_node`: Reference to the current `ThoughtNode` in the Tree of Thought.
 
 **Action Application (`apply_action`):**
+```
 \[
 \text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
 \]
 \[
 \text{New Representation} = \text{Concat}(\text{Current Representation}, \text{Next State} \rightarrow \text{unsqueeze}(1))
 \]
+```
 **Procedure:**
 
 1. **Action Encoding:**
+
+```
 \[
 \text{Action Index} = \text{Index of Action}
 \]
 \[
 \text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
 \]
+```
 2. **State Extraction:**
+
+```
 \[
 \text{Current State} = \text{representation}[:, -1, :]
 \]
+```
 3. **State Transition:**
+
+```
 \[
 \text{Next State Representation} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
 \]
+```
 4. **Representation Update:**
+
+```
 \[
 \text{New Representation} = \text{Concat}(\text{representation}, \text{Next State Representation} \times \text{unsqueeze}(1))
 \]
+```
 5. **Thought Node Update:**
   - Navigate to the child `ThoughtNode` corresponding to the applied action.
 
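The five steps above can be traced end to end with small stand-ins; the `Node` class, the plain `nn.Linear` in place of the real DynamicsNetwork, and all names are illustrative assumptions:

```python
import torch
import torch.nn as nn

class Node:  # minimal Tree of Thought node stand-in
    def __init__(self, name, children=()):
        self.name, self.children = name, list(children)

actions = ["expand", "refine"]
d_state, d_action = 16, 8
action_encoder = nn.Embedding(len(actions), d_action)  # stand-in ActionEncoder
dynamics = nn.Linear(d_state + d_action, d_state)      # stand-in DynamicsNetwork

representation = torch.randn(1, 3, d_state)            # (batch, steps, d_state)
thought_node = Node("root", [Node("expand"), Node("refine")])

# 1. Action encoding
action_index = torch.tensor([actions.index("expand")])
action_embedding = action_encoder(action_index)
# 2. State extraction: take the last position of the running representation
current_state = representation[:, -1, :]
# 3. State transition
next_state = dynamics(torch.cat([current_state, action_embedding], dim=-1))
# 4. Representation update: append the new state along the step dimension
representation = torch.cat([representation, next_state.unsqueeze(1)], dim=1)
# 5. Thought node update: descend to the child matching the applied action
thought_node = next(c for c in thought_node.children if c.name == "expand")
```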
 **Mathematical Representation:**
 
 Each `MCTSNode` can be considered as:
+```
 \[
 \text{MCTSNode} = (\text{state}, \text{parent}, \text{action}, \{\text{children}\}, \text{visit\_count}, \text{value\_sum}, \text{prior}, \text{entropy}, \text{variance})
 \]
+```
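The tuple above maps naturally onto a Python dataclass; the field types and the `value()` helper are assumptions for the sketch:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class MCTSNode:
    """Sketch of the MCTSNode tuple above."""
    state: object
    parent: Optional["MCTSNode"] = None
    action: Optional[str] = None
    children: dict = field(default_factory=dict)
    visit_count: int = 0
    value_sum: float = 0.0
    prior: float = 0.0
    entropy: float = 0.0
    variance: float = 0.0

    def value(self) -> float:
        """Average value; 0 for unvisited nodes."""
        return self.value_sum / self.visit_count if self.visit_count else 0.0

root = MCTSNode(state="s0")
root.children["expand"] = MCTSNode(state="s1", parent=root,
                                   action="expand", prior=0.5)
```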
 #### MCTS Algorithm
 
 The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to the LightBulb model's architecture.
 
   - Add the candidate sequence to `all_candidates`.
 - **Beam Pruning:**
   - Sort all candidates based on a combined score:
+```
 \[
 \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
 \]
+```
   - Retain the top `beam_size` candidates for the next iteration.
 4. **Result Extraction:**
   - After completing iterations, select the best action sequence from the final beam.
 
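The pruning step above can be sketched directly from the formula; the candidate dictionary fields mirror the beam tuple and are assumptions:

```python
def combined_score(candidate: dict) -> float:
    """Score from the formula above: log-prob score,
    entropy penalty (0.1), variance bonus (0.05)."""
    return (candidate["score"]
            - 0.1 * candidate["entropy"]
            + 0.05 * candidate["variance"])

all_candidates = [
    {"seq": ["a"], "score": -1.2, "entropy": 2.0, "variance": 0.4},
    {"seq": ["b"], "score": -1.3, "entropy": 0.5, "variance": 0.1},
]
beam_size = 1
# Sort by combined score and retain the top beam_size candidates
beam = sorted(all_candidates, key=combined_score, reverse=True)[:beam_size]
```

Here the lower-entropy candidate `"b"` wins despite its lower raw score, which is the point of the entropy penalty.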
 - Calculate entropy and variance of the policy distribution.
 - Expand the node by creating child nodes based on the Tree of Thought and assign priors from policy probabilities.
 - **Mathematical Operations:**
+```
 \[
 (\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
 \]
 
 \[
 \text{Variance} = \text{Var}(P)
 \]
+```
 4. **Backpropagation (`backpropagate`):**
   - **Function:** Updates the `visit_count` and `value_sum` for nodes along the path from the evaluated node back to the root.
   - **Procedure:**
+```
 \[
 \text{For each node in the path:} \\
 \quad \text{node.visit\_count} \mathrel{+}= 1 \\
 \quad \text{node.value\_sum} \mathrel{+}= \text{Value Estimate}
 \]
+```
 5. **Upper Confidence Bound (UCB) Score (`ucb_score`):**
   - **Function:** Balances exploration of less-visited nodes and exploitation of high-value nodes.
   - **Mathematical Operation:**
+
+```
 \[
 \text{UCB Score} = \text{Average Value} + \text{Exploration Term} + \text{Entropy Term} + \text{Variance Term}
 \]
 
 \[
 \text{Variance Term} = 0.05 \times \text{variance}
 \]
+```
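A sketch of this score. The 0.05 variance weight is from the text; the exploration term's exact form and the 0.1 entropy weight are not shown here and are assumptions (the exploration term below uses a common prior-weighted visit-count form):

```python
import math

def ucb_score(parent_visit_count: int, child: dict, c_puct: float = 1.25,
              entropy_weight: float = 0.1, variance_weight: float = 0.05) -> float:
    """UCB = average value + exploration + entropy term + variance term."""
    average_value = (child["value_sum"] / child["visit_count"]
                     if child["visit_count"] else 0.0)
    # Assumed prior-weighted exploration term (not specified in the text)
    exploration = (c_puct * child["prior"]
                   * math.sqrt(parent_visit_count) / (1 + child["visit_count"]))
    return (average_value + exploration
            + entropy_weight * child["entropy"]
            + variance_weight * child["variance"])

child = {"value_sum": 3.0, "visit_count": 4, "prior": 0.5,
         "entropy": 0.2, "variance": 0.1}
score = ucb_score(parent_visit_count=16, child=child)
```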
 6. **Best Action Sequence Extraction (`best_action_sequence`):**
   - **Function:** Extracts the most promising action sequence from the MCTS tree after all iterations.
   - **Procedure:**
 
 ---
 
 ### Beam Search with Multi-Token Prediction
 
 **Purpose:** Efficiently explores multiple possible token sequences to generate coherent and diverse outputs by predicting multiple tokens at each step.
 
 **Procedure:**
 
 1. **Beam Initialization:**
+
+```
 - Start with a beam containing the start-of-sequence (BOS) token.
 \[
 \text{beam} = \left\{ \left( \text{seq} = [\text{BOS}], \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0 \right) \right\}
 \]
+```
 2. **Iterative Expansion:**
+   - For each iteration up to
+```
+\( \frac{\text{max\_length}}{n\_tokens\_predict} \)
+```
+:
   - For each sequence in the beam:
+     - Predict the next:
+```
+\( n\_tokens\_predict \)
+
+```
+tokens.
     - Calculate their probabilities.
     - Select top-k token sequences based on cumulative scores.
 
 - Continue until the maximum length is reached or all sequences end with the end-of-sequence (EOS) token.
 
 **Mathematical Operations:**
+```
 \[
 \text{Score} = \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
 \]
 
 \[
 \text{Variance} = \text{Var}(P)
 \]
+```
 Where \( P \) is the probability distribution over the vocabulary.
 
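The per-step quantities above can be computed as follows for a toy vocabulary distribution; the function name and the greedy token choice are illustrative:

```python
import math

def step_stats(probs, chosen_index):
    """For a distribution P over the vocabulary: the chosen token's
    log-probability, the entropy of P, and the variance of P."""
    log_p = math.log(probs[chosen_index])
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    mean = sum(probs) / len(probs)
    variance = sum((p - mean) ** 2 for p in probs) / len(probs)
    return log_p, entropy, variance

log_p, entropy, variance = step_stats([0.7, 0.2, 0.1], chosen_index=0)
# A sequence's Score is the sum of log_p over its predicted tokens.
```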
 ### Upper Confidence Bound (UCB) in MCTS