RobbiePasquale committed on
Commit 5977163 · verified · 1 Parent(s): 8793340

Update README.md

Files changed (1): README.md (+144 −1057)
---
license: apache-2.0
---
## Overview of the Main Menu

The `main_menu.py` script is the primary entry point for choosing and executing one of three tasks:

1. **Training the LLM and World Model**: `train_llm_world`
2. **Training the Search Agent**: `train_agent`
3. **Testing the Tree of Thought Search Agent**: `test_agent`

Each task has its own functionality and configuration. The script uses command-line arguments to select the desired task and pass additional options, letting users tailor execution to their needs.

### Running the Main Menu

To run the main menu, use the following command in the terminal:
```bash
python main_menu.py --task <task_name> [additional arguments]
```

Replace `<task_name>` with one of the following:
- `train_llm_world` - Train the LLM (Language Model) and World Model.
- `train_agent` - Train the Search Agent with an interactive Twisted-based process.
- `test_agent` - Test the Tree of Thought Search Agent, either interactively or with a single query.

### General Arguments

The script supports a set of command-line arguments to customize each task. Here is an overview of all possible arguments:
| Argument | Required | Description | Default |
|------------------------|----------|------------------------------------------------------------------------------------------------------------------------------------|--------------------------|
| `--task` | Yes | Task to run: `train_llm_world`, `train_agent`, or `test_agent`. | None |
| `--model_name` | No | Pretrained model name for the LLM (e.g., `gpt2`, `bert`) or a path to a custom model. | `gpt2` |
| `--dataset_name` | No | Name of the dataset from Hugging Face Datasets for training the LLM and World Model (e.g., `wikitext`). | `wikitext` |
| `--dataset_config` | No | Dataset configuration name selecting a specific version or variant of the dataset. | `wikitext-2-raw-v1` |
| `--batch_size` | No | Number of samples processed in a single forward/backward pass. Larger batches can speed up training but require more memory. | `4` |
| `--num_epochs` | No | Number of passes over the training dataset. More epochs generally improve learning but can lead to overfitting. | `3` |
| `--max_length` | No | Maximum sequence length for training/inference; sequences are truncated or padded to this length. | `128` |
| `--mode` | No | Mode for the LLM and World Model: `train` for training, `inference` for generating responses. | `train` |
| `--query` | No | Query input for `test_agent` when running a single query instead of an interactive session. | `''` (empty) |
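A minimal sketch of how these arguments might be parsed and dispatched (a hypothetical reconstruction of the CLI described in the table; the actual `main_menu.py` may differ in details):

```python
import argparse

# Hypothetical reconstruction of the CLI described above;
# the real main_menu.py may differ in details.
def build_parser():
    parser = argparse.ArgumentParser(description="main_menu.py task selector")
    parser.add_argument("--task", required=True,
                        choices=["train_llm_world", "train_agent", "test_agent"])
    parser.add_argument("--model_name", default="gpt2")
    parser.add_argument("--dataset_name", default="wikitext")
    parser.add_argument("--dataset_config", default="wikitext-2-raw-v1")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mode", choices=["train", "inference"], default="train")
    parser.add_argument("--query", default="")
    return parser
```

Unspecified options fall back to the defaults listed in the table, so `python main_menu.py --task train_agent` is a complete invocation.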
## Task Details

### 1. Training the LLM and World Model (`train_llm_world`)

This task trains the LLM and the World Model on a chosen Hugging Face dataset. Training adjusts model weights over several epochs, producing a model capable of handling long sequences and complex reasoning tasks.

#### Example Usage
```bash
python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256
```
#### Arguments Specific to `train_llm_world`
- **`--model_name`**: Name of the pretrained model to use for language model training. You can specify a model name (such as `gpt2` or `bert`) or a path to a custom model. This choice determines the model architecture and tokenization style.

- **`--dataset_name`**: Dataset from Hugging Face's Datasets library to train on. Options include `wikitext`, `imdb`, `squad`, and custom datasets specified by path.

- **`--dataset_config`**: Configuration of the dataset, selecting among its versions or variations. For example, `wikitext` includes configurations such as `wikitext-2-raw-v1`. The configuration affects the format and content of the data.

- **`--batch_size`**: Number of samples per batch. A larger batch size requires more memory but can improve training speed; reduce it if memory is limited.

- **`--num_epochs`**: Number of complete passes through the training dataset. More epochs can improve learning but may lead to overfitting if set too high.

- **`--max_length`**: Maximum input sequence length. Longer sequences are truncated and shorter ones padded. This affects both training and inference.

- **`--mode`**: Task to perform. Choose `train` to train the model; with `inference`, the model generates text from the input.
### 2. Training the Search Agent (`train_agent`)
The following is a detailed breakdown of the search agent, covering training, inference, and the functionality of each component. It also highlights how the agent saves LLM training data, its modular structure, and the role of each module.

---

## Overview of the AutonomousWebAgent

The `AutonomousWebAgent` is a multi-component search and retrieval agent designed to navigate the web, gather relevant content, and perform summarization and generation based on user queries. It integrates reinforcement learning (RL), Monte Carlo Tree Search (MCTS), a Retrieval-Augmented Generation (RAG) Summarizer, and a Hierarchical Reinforcement Learning (HRL) architecture to select, execute, and optimize its actions based on past experiences.
### Key Components

1. **Prioritized Experience Replay**:
   - The agent uses a `PrioritizedReplayMemory` and a `SumTree` to prioritize and store experiences (transitions between states).
   - The `SumTree` maintains a binary tree where each parent node's value is the sum of its children, allowing experiences to be stored, updated, and sampled by priority efficiently.
   - These experiences train both the high-level (manager) and low-level (worker) components through prioritized sampling during replay, letting the models focus on the most significant transitions.

2. **Hierarchical Reinforcement Learning (HRL)**:
   - A **Manager** (high-level) model selects options, which a **Worker** (low-level) model executes. The `ManagerModel` selects tasks (such as searching, summarizing, or generating), while the `WorkerModel` determines the specific actions to take.
   - Both use LSTM networks with fully connected layers, and each has its own replay memory and optimization process.
   - The Manager makes broad decisions about options while the Worker carries out specific actions, enabling a layered approach to decision-making.

3. **RAGSummarizer**:
   - The `RAGSummarizer` uses a pretrained language model (e.g., GPT-2) for summarizing and a SentenceTransformer for embedding-based retrieval. It splits input content into chunks, retrieves the sections most relevant to the query by cosine similarity, and generates a coherent summary.
   - It implements a Least Recently Used (LRU) cache, with persistent storage, to avoid redundant computation and improve efficiency.
   - Summarized results are stored, and this module contributes directly to the generation of LLM training data.

4. **WorldModel**:
   - Encapsulates an LSTM architecture with linear layers and a `value_head` that estimates state values, letting the agent anticipate the long-term value of its actions.
   - Within the HRL architecture, it is used by the Worker to evaluate actions and by the Manager for long-term decision-making.

5. **Knowledge Base**:
   - Acts as a repository for collected data, maintaining embeddings for efficient search and retrieval.
   - Supports saving and loading document embeddings, so the agent can answer new queries from previously collected knowledge.
   - Adding to and retrieving from the knowledge base enriches the agent's context, letting information from past experiences inform current tasks.

6. **Monte Carlo Tree Search (MCTS)**:
   - Guides the agent through complex decision trees to find the most promising query refinements.
   - Nodes in the tree represent states (possible query refinements); child nodes represent possible expansions (e.g., related query variations).
   - MCTS iterates `select`, `expand`, `simulate`, and `backpropagate` steps to refine queries, scoring them on relevance and other metrics to converge on optimal searches.
   - It integrates RL by backpropagating rewards based on the ranking score of retrieved results.

7. **Ranking Model**:
   - A neural network built on the `SentenceTransformer` that ranks search results by features such as cosine similarity with the query, content length, keyword overlap, and domain authority.
   - The resulting scores guide the MCTS process by augmenting the combined reward.

8. **Tree of Thought (ToT) Search**:
   - Enhances the agent's capability to generate a series of interconnected thoughts, exploring different perspectives on a given query.
   - The `ToTNode` and `ToTSearch` classes generate thoughts, evaluate them, and navigate them as a tree, considering multiple potential paths to best answer the query.
   - It combines MCTS and RAG to synthesize responses based on the generated thought paths.
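As an illustration of the priority structure described above, a minimal `SumTree` for priority-proportional sampling might look like the following (a sketch, not the agent's actual implementation):

```python
class SumTree:
    """Binary tree where each parent stores the sum of its children,
    enabling O(log n) priority-proportional sampling. Illustrative sketch."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)  # internal nodes + leaves
        self.data = [None] * capacity
        self.write = 0  # index of the next leaf to overwrite

    def add(self, priority, experience):
        idx = self.write + self.capacity - 1
        self.data[self.write] = experience
        self.update(idx, priority)
        self.write = (self.write + 1) % self.capacity

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        while idx > 0:  # propagate the change up to the root
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def total(self):
        return self.tree[0]  # root holds the total priority mass

    def sample(self, s):
        """Walk down the tree to the leaf covering cumulative priority s."""
        idx = 0
        while 2 * idx + 1 < len(self.tree):
            left = 2 * idx + 1
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1
        return self.data[idx - self.capacity + 1]
```

Drawing `s` uniformly from `[0, tree.total())` then makes high-priority transitions proportionally more likely to be replayed.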
### Training Process

The agent is trained episodically, interacting with queries from a predefined list. Each query initiates an episode, and the agent performs actions based on its learned policy:

1. **Search and Summarization**:
   - The agent performs search operations, gathering relevant content from online sources using the MCTS and Ranking Model for prioritization.
   - The retrieved content is then summarized, with relevant information stored in the LLM training data.

2. **Knowledge Base and LLM Training Data Storage**:
   - Throughout training, the agent stores retrieved documents, query results, and summaries in its knowledge base and saves training data for future LLM fine-tuning.
   - The data is saved in JSONL format with metadata such as query terms, source links, and summaries, making it valuable for training language models.

3. **Experience Replay**:
   - Both the manager and worker models perform prioritized experience replay, sampling stored experiences from the SumTree based on TD-errors.
   - Replay reinforces successful transitions and updates the models' policies over time.

4. **Reward Calculation and Backpropagation**:
   - Rewards are calculated from ranking scores, cosine similarity with the query, and other custom factors (e.g., query complexity, state length).
   - These rewards are backpropagated through the MCTS and used to update the models' decision-making processes, ensuring continuous learning and adaptation.
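A hedged sketch of the TD-error-based prioritization used in this kind of replay (the `alpha`, `eps`, and `beta` defaults are illustrative assumptions, not the agent's actual hyperparameters):

```python
def td_error_priority(td_error, alpha=0.6, eps=1e-2):
    # Proportional prioritization: p = (|delta| + eps) ** alpha.
    # eps keeps zero-error transitions sampleable; alpha in [0, 1]
    # interpolates between uniform (0) and greedy (1) sampling.
    return (abs(td_error) + eps) ** alpha

def importance_weight(prob, n, beta=0.4):
    # Importance-sampling weight (n * P(i)) ** -beta, correcting the
    # bias that prioritized sampling introduces into gradient updates.
    return (n * prob) ** (-beta)
```

Priorities feed the SumTree's `update`; the weights scale each sampled transition's loss during replay.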
### Inference Process

During inference:
- The agent accepts a query, and the Manager model selects a high-level action based on its policy (e.g., search, summarize, or generate).
- Once an option is chosen, the Worker model executes the corresponding low-level actions. In a search operation, for example, it uses MCTS to refine the query, retrieves relevant web content, and processes it with the RAGSummarizer.
- Each inference step draws on the agent's existing knowledge base, producing more informed and contextually rich responses. If Tree of Thought (ToT) search is employed, the agent synthesizes a coherent, comprehensive answer from the thought path.

### Model Saving

The agent provides several save functions:
- `save_worker_model` and `save_manager_model` save the worker and manager models independently.
- The `save` method preserves the agent's overall state, including its knowledge base, replay memories, and models. This enables model reuse and lets the agent resume from saved states during training or deployment.

---

This modular setup gives the agent flexibility: it adjusts its behavior based on RL rewards, improves through experience replay, and decides efficiently via MCTS. By saving LLM training data, it also produces reusable material for further fine-tuning, enabling specialized, data-driven language models optimized for specific domains or tasks.

This task uses Twisted to train the Autonomous Web Agent by interacting with various queries in a simulated or real environment. It collects rewards based on how well the agent navigates and summarizes web content or performs other tasks.
#### Example Usage
```bash
python main_menu.py --task train_agent
```

#### Process Details
- **Training**: During training, the agent samples from a list of predefined queries, explores web pages, and uses reinforcement learning to maximize its reward. The training log provides insight into each episode's reward and the agent's progress.

- **Logging**: Logs are written to `agent_training.log` and record each episode's query, total reward, and duration. Errors are logged, and if an episode times out, a negative reward is given.
### 3. Testing the Tree of Thought Search Agent (`test_agent`)

This task lets you test the Tree of Thought Search Agent either interactively or with a single query. In interactive mode, you can repeatedly enter queries, and the agent processes them sequentially, producing responses based on the Tree of Thought architecture.

#### Example Usage
Interactive Mode:
```bash
python main_menu.py --task test_agent
```

Single Query Mode:
```bash
python main_menu.py --task test_agent --query "What are the impacts of renewable energy on global sustainability?"
```

#### Arguments Specific to `test_agent`
- **`--query`**: If provided, the agent processes this specific query and returns a response, which is ideal for quick, one-off tests or evaluations. If omitted, the program starts an interactive session where you can repeatedly input queries and view the agent's responses.

#### Interactive Mode Details
- **Input**: Enter a query and press Enter. The agent responds based on its training and the Tree of Thought methodology, traversing different thought paths to generate a response.

- **Exiting**: Type `quit` and press Enter. The agent saves any new knowledge it has gained and exits.
---
## World Model

1. **Transformer**: Produces the contextual output embeddings consumed by the rest of the World Model.
2. **Representation Network**: Encodes the Transformer output into a state representation, reducing dimensionality and making it suitable for further processing.
3. **Dynamics Network**: Predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
4. **Prediction Network**: Predicts both the policy logits and value estimates for a given state. It outputs the probabilities of different actions as well as a single scalar value.
5. **MCTS**: Performs Monte Carlo Tree Search to evaluate the quality of actions over multiple iterations. It expands nodes based on the policy logits from the Prediction Network and simulates reward by backpropagating value estimates.
6. **PPO Agent**: Uses policy and value estimates to compute the PPO loss, which updates the policy while constraining the KL divergence between old and new policies.
The Transformer strategically utilises beam search as well as multi-token prediction in order to enrich the encoding from the Representation Network.

A generated sequence of tokens is an action. For example, if t denotes a token, then an action is:

a_1 = {t_1, ..., t_N}

A policy is then a sequence of actions:

P_1 = {a_1, ..., a_N}

The MCTS and OOPS components explore what we define as 'thoughts', where a thought is a set of policies:

thought_1 = {P_1, ..., P_N}

The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each level of granularity.
## Training Details

The model is trained with the following components and techniques:

### Training Procedure
- **Data Loading**: The data is tokenized with attention to padding and truncation, and text is grouped into fixed-length sequences for efficient training.
- **Optimization**: Training uses the **AdamW** optimizer with a **CosineAnnealingLR** scheduler for learning rate adjustment. A **gradient scaler** helps prevent overflow when training with mixed precision.
- **Gradient Accumulation**: Since the model can be computationally heavy, gradients are accumulated over several steps to reduce memory usage.
- **Loss Functions**: The training process leverages a comprehensive set of custom loss functions:
**1. InfoNCE Loss (Info Noise Contrastive Estimation):**
Definition: Used for contrastive learning; it encourages similar samples to be close in the embedding space while pushing dissimilar samples apart.

Formula:
```
L_InfoNCE = -log[ exp(sim(z_i, z_j) / τ) / Σ_k exp(sim(z_i, z_k) / τ) ]
```
where sim() is the cosine similarity, τ is the temperature parameter, z_i and z_j are paired samples, and the sum in the denominator runs over all other samples in the batch.
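A minimal, framework-free sketch of the InfoNCE computation for a single anchor (illustrative; the actual implementation presumably operates on batched tensors):

```python
import math

def cosine_sim(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def info_nce(anchor, positive, negatives, temperature=0.07):
    """InfoNCE for one anchor: negative log of the softmax probability
    assigned to the positive among {positive} + negatives.
    temperature=0.07 is an illustrative default."""
    logits = [cosine_sim(anchor, positive) / temperature]
    logits += [cosine_sim(anchor, n) / temperature for n in negatives]
    m = max(logits)  # log-sum-exp trick for numerical stability
    log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
    return -(logits[0] - log_denom)
```

The loss is near zero when the anchor matches its positive and grows when a negative is more similar than the positive.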
**2. Covariance Regularization:**
Definition: Encourages the learned representations to have uncorrelated dimensions, promoting more diverse and informative embeddings.

Formula:
```
L_cov = λ * (Σ_{i,j} Cov(i,j)^2 - Σ_i Cov(i,i)^2)
```
where Cov is the covariance matrix of the embeddings and λ is a regularization coefficient; subtracting the diagonal terms means only off-diagonal covariance is penalized.
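The off-diagonal penalty can be sketched directly (pure-Python illustration; a real implementation would use batched tensor operations, and `lam` is an illustrative default):

```python
def covariance_penalty(embeddings, lam=1.0):
    """lam * sum of squared off-diagonal covariance entries.
    `embeddings` is a list of equal-length vectors (one per sample)."""
    n = len(embeddings)        # number of samples
    d = len(embeddings[0])     # embedding dimension
    means = [sum(e[k] for e in embeddings) / n for k in range(d)]
    centered = [[e[k] - means[k] for k in range(d)] for e in embeddings]
    penalty = 0.0
    for i in range(d):
        for j in range(d):
            if i == j:
                continue  # diagonal (variance) terms are not penalized
            cov_ij = sum(c[i] * c[j] for c in centered) / (n - 1)
            penalty += cov_ij ** 2
    return lam * penalty
```

Uncorrelated dimensions yield a penalty of zero; perfectly correlated ones are penalized heavily.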
**3. Dynamics Performance Loss:**
Definition: Measures the accuracy of predicted next states while also encouraging diverse predictions.

Formula:
```
L_dynamics = MSE(true_next_state, predicted_next_state) + λ * Var(predicted_next_state)
```
where MSE is the mean squared error, Var is the variance, and λ is a weighting factor.

**4. Thought Consistency Loss:**
Definition: Encourages consistency between true next states and perturbed next states.

Formula:
```
L_consistency = MSE(true_next_state, perturbed_next_state)
```
**5. Policy Value Joint Loss:**
Definition: Combines policy and value losses for reinforcement learning tasks.

Formula:
```
L_joint = CrossEntropy(policy_logits, true_policy) + λ * MSE(value_pred, true_value)
```
where λ is a weighting factor balancing policy and value losses.
**6. Action Diversity Reward:**
Definition: Encourages diversity in action embeddings.

Formula:
```
R_diversity = λ * Σ_{i,j} cos_sim(a_i, a_j)^2
```
where cos_sim is the cosine similarity between action embeddings and λ is a scaling factor.
**7. Expected Thought Value Loss:**
Definition: Maximizes the expected value obtained from Monte Carlo Tree Search.

Formula:
```
L_ETV = -mean(mcts_best_values)
```
**8. Exploration Regularization:**
Definition: Encourages exploration by rewarding less-visited actions.

Formula:
```
R_exploration = λ * mean(Σ_a 1 / (visit_count(a) + 1))
```
where λ is a scaling factor.
**9. KL Divergence Loss:**
Definition: Measures the difference between old and new policies in policy optimization.

Formula:
```
L_KL = KL(old_policy || new_policy) = \sum_{i=1}^{n} old\_policy_i \cdot \log\left(\frac{old\_policy_i}{new\_policy_i}\right)
```
where:
- L_KL is the KL divergence loss,
- old_policy and new_policy are probability distributions,
- i indexes each possible outcome or action,
- n is the total number of possible outcomes or actions.
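The sum above translates directly into code (a framework-free sketch; a PPO implementation would compute this over batched policy tensors):

```python
import math

def kl_divergence(old_policy, new_policy):
    """KL(old || new) = sum_i old_i * log(old_i / new_i).
    Both inputs are probability distributions over the same actions;
    zero-probability terms contribute nothing by convention."""
    return sum(o * math.log(o / n_)
               for o, n_ in zip(old_policy, new_policy) if o > 0)
```

The divergence is zero when the policies match and grows as the new policy drifts from the old one, which is exactly what the PPO constraint limits.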
### Evaluation
After each epoch, the model is evaluated on the validation set, computing the average loss over the dataset. Evaluation uses the same loss functions as training but does not backpropagate, so it can run in inference mode.

### Checkpoints
At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning.
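As one concrete piece of the optimization setup described above, the schedule behind `CosineAnnealingLR` can be written out directly (a sketch of the formula, not the project's training code):

```python
import math

def cosine_annealing_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine-annealed learning rate: starts at lr_max and decays
    smoothly to lr_min over total_steps, following
    lr = lr_min + (lr_max - lr_min) * (1 + cos(pi * step / total_steps)) / 2."""
    cos_factor = (1 + math.cos(math.pi * step / total_steps)) / 2
    return lr_min + (lr_max - lr_min) * cos_factor
```

The slow decay near the start and end, with the steepest drop in the middle, is what makes cosine annealing a gentle alternative to step schedules.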
 
 
 
---

### World Model Components

The World Model encapsulates components that model state representations, dynamics, predictions, and action encodings. These components interact with the Transformer to simulate and predict state transitions within the Tree of Thought framework.

#### RepresentationNetwork

**Function:**
Transforms the transformer's output embeddings into a compact state representation suitable for modeling and prediction tasks.

**Mathematical Operation:**
```
\[
\text{State} = \text{LayerNorm}\left(\text{Linear}(d_{\text{model}} \rightarrow d_{\text{state}})\left(\text{Linear}(vocab\_dim \rightarrow d_{\text{model}})(\text{Transformer Output})\right)\right)
\]
```
**Explanation:**
Sequential linear transformations project high-dimensional embeddings into a lower-dimensional state space, followed by layer normalization for stability.
#### DynamicsNetwork

**Function:**
Models how the state evolves in response to actions (thoughts) taken by the model.

**Mathematical Operation:**
```
\[
\text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
\]
```

**Explanation:**
Predicts the subsequent state by combining the current state representation with an encoded action, effectively simulating the consequences of actions within the Tree of Thought.
#### PredictionNetwork

**Function:**
Predicts policy logits (action probabilities) and value estimates (state evaluations) based on the current state.

**Mathematical Operation:**
```
\[
(\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
\]
```
**Explanation:**
- **Policy Logits:** Used to derive action probabilities via softmax.
- **Value Estimate:** Represents the expected reward or quality of the current state.

#### ActionEncoder

**Function:**
Encodes discrete actions (thoughts) into continuous embeddings compatible with the DynamicsNetwork.

**Mathematical Operation:**
```
\[
\text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
\]
```
**Explanation:**
Converts action indices into dense vector representations, facilitating their integration into state transition modeling.
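The four components above can be sketched as small PyTorch modules (a hypothetical reconstruction; dimensions, layer counts, and exact layer choices are illustrative, not the project's actual architecture):

```python
import torch
import torch.nn as nn

class RepresentationNetwork(nn.Module):
    # Two linear projections followed by LayerNorm, as described above.
    def __init__(self, vocab_dim, d_model, d_state):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vocab_dim, d_model),
            nn.Linear(d_model, d_state),
            nn.LayerNorm(d_state),
        )
    def forward(self, transformer_output):
        return self.proj(transformer_output)

class ActionEncoder(nn.Module):
    # Maps discrete action indices to dense embeddings.
    def __init__(self, num_actions, action_dim):
        super().__init__()
        self.embed = nn.Embedding(num_actions, action_dim)
    def forward(self, action_index):
        return self.embed(action_index)

class DynamicsNetwork(nn.Module):
    # Predicts the next state from (state, action embedding),
    # using LayerNorm and GELU as the section describes.
    def __init__(self, d_state, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_state + action_dim, d_state),
            nn.LayerNorm(d_state),
            nn.GELU(),
            nn.Linear(d_state, d_state),
        )
    def forward(self, state, action_embedding):
        return self.net(torch.cat([state, action_embedding], dim=-1))

class PredictionNetwork(nn.Module):
    # Outputs policy logits and a scalar value estimate per state.
    def __init__(self, d_state, num_actions):
        super().__init__()
        self.policy_head = nn.Linear(d_state, num_actions)
        self.value_head = nn.Linear(d_state, 1)
    def forward(self, state):
        return self.policy_head(state), self.value_head(state).squeeze(-1)
```

Chained together, these give the simulate-and-evaluate loop that MCTS drives: represent, act, predict the next state, then score it.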
---

### Tree of Thought (ToT)

The Tree of Thought provides a structured representation of the possible thoughts/actions the model can take, organized hierarchically to enable efficient exploration during reasoning.

#### ThoughtNode

**Function:**
Represents a node in the Tree of Thought, corresponding to a specific action or thought.

**Attributes:**

- `name`: Identifier for the thought/action.
- `children`: List of child `ThoughtNode` instances representing possible subsequent thoughts/actions.
- `parent`: Reference to the parent `ThoughtNode`.

**Mathematical Representation:**

Each `ThoughtNode` can be represented as a tree node in a directed graph:
```
\[
\text{ThoughtNode} = (\text{name}, \{\text{children}\})
\]
```
#### State

**Function:**
Represents the current state within the MCTS and Tree of Thought framework.

**Attributes:**

- `representation`: Tensor capturing the current state, shaped as \((\text{batch\_size}, \text{seq\_len}, \text{state\_dim})\).
- `dynamics_network`: Reference to the `DynamicsNetwork` for state transitions.
- `action_encoder`: Reference to the `ActionEncoder` for encoding actions.
- `thought_node`: Reference to the current `ThoughtNode` in the Tree of Thought.

**Action Application (`apply_action`):**
```
\[
\text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
\]
\[
\text{New Representation} = \text{Concat}(\text{Current Representation}, \text{Next State}.\text{unsqueeze}(1))
\]
```
**Procedure:**

1. **Action Encoding:**

   ```
   \[
   \text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
   \]
   ```
2. **State Extraction:**

   ```
   \[
   \text{Current State} = \text{representation}[:, -1, :]
   \]
   ```
3. **State Transition:**

   ```
   \[
   \text{Next State Representation} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
   \]
   ```
4. **Representation Update:**

   ```
   \[
   \text{New Representation} = \text{Concat}(\text{representation}, \text{Next State Representation}.\text{unsqueeze}(1))
   \]
   ```
5. **Thought Node Update:**
   - Navigate to the child `ThoughtNode` corresponding to the applied action.
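The steps above can be sketched as an `apply_action` method (an illustrative reconstruction; it assumes the `DynamicsNetwork`/`ActionEncoder` call signatures described earlier and omits the thought-node bookkeeping):

```python
import torch

class State:
    # Minimal sketch; the real class also tracks the current ThoughtNode.
    def __init__(self, representation, dynamics_network, action_encoder):
        self.representation = representation  # (batch, seq_len, state_dim)
        self.dynamics_network = dynamics_network
        self.action_encoder = action_encoder

    def apply_action(self, action_index):
        # 1. Encode the discrete action into a dense embedding.
        action_emb = self.action_encoder(action_index)
        # 2. Extract the latest state along the sequence dimension.
        current = self.representation[:, -1, :]
        # 3. Predict the next state from (state, action embedding).
        next_state = self.dynamics_network(current, action_emb)
        # 4. Append the next state to the running representation.
        new_repr = torch.cat([self.representation,
                              next_state.unsqueeze(1)], dim=1)
        return State(new_repr, self.dynamics_network, self.action_encoder)
```

Each applied action therefore grows the sequence dimension by one, giving MCTS a rolling history of simulated states.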
---

### Monte Carlo Tree Search (MCTS)

MCTS is an algorithm for making optimal decisions by traversing the Tree of Thought, balancing exploration of new actions against exploitation of known rewarding actions using statistical methods.

#### MCTSNode

**Function:**
Represents a node in the MCTS search tree, encapsulating a specific state in the search process.

**Attributes:**

- `state`: Current state, a `State` instance.
- `parent`: Reference to the parent `MCTSNode`.
- `action`: Action taken to reach this node.
- `children`: Dictionary mapping actions to child `MCTSNode` instances.
- `visit_count`: Number of times this node has been visited.
- `value_sum`: Cumulative value obtained from simulations passing through this node.
- `prior`: Prior probability of selecting this action, derived from the policy logits.
- `entropy`: Entropy of the policy distribution at this node.
- `variance`: Variance of the policy distribution at this node.

**Mathematical Representation:**

Each `MCTSNode` can be written as:
```
\[
\text{MCTSNode} = (\text{state}, \text{parent}, \text{action}, \{\text{children}\}, \text{visit\_count}, \text{value\_sum}, \text{prior}, \text{entropy}, \text{variance})
\]
```
496
-
497
- The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to the LightBulb model's architecture.
498
-
499
- **Key Steps:**
500
-
501
- 1. **Initialization:**
502
- - Initialize with the `PredictionNetwork`, `DynamicsNetwork`, `ActionEncoder`, number of iterations (`num_iterations`), exploration constant (`exploration_constant`), beam size (`beam_size`), and number of tokens to predict (`n_tokens_predict`).
503
-
504
- 2. **Search with Beam (`search_with_beam`):**
505
- - **Objective:** Explore the Tree of Thought using beam search augmented with MCTS principles.
506
- - **Procedure:**
507
- 1. **Root Node Evaluation and Backpropagation:**
508
- - Evaluate the root node to obtain policy logits and value estimates.
509
-        - Backpropagate the value estimate to update visit counts and value sums.
-   2. **Beam Initialization:**
-      - Start with a beam containing the root node.
-   3. **Iterative Expansion:**
-      - For each iteration up to `num_iterations`:
-        - **Candidate Collection:**
-          - For each node in the current beam:
-            - If it's a leaf node, evaluate and backpropagate its value.
-            - If it has children, select top `beam_size` actions based on UCB scores.
-            - For each selected action:
-              - Predict a sequence of `n_tokens_predict` actions.
-              - Accumulate scores, entropy, and variance.
-              - Add the candidate sequence to `all_candidates`.
-        - **Beam Pruning:**
-          - Sort all candidates based on a combined score:
-            \[
-            \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
-            \]
-          - Retain the top `beam_size` candidates for the next iteration.
-   4. **Result Extraction:**
-      - After completing iterations, select the best action sequence from the final beam.
- 
- 3. **Evaluation (`evaluate`):**
-    - **Function:** Computes the policy logits and value estimates for a given node's state.
-    - **Procedure:**
-      - Extract the last time step's state representation.
-      - Pass it through the `PredictionNetwork` to obtain policy logits and a value estimate.
-      - Convert logits to probabilities using softmax.
-      - Calculate entropy and variance of the policy distribution.
-      - Expand the node by creating child nodes based on the Tree of Thought and assign priors from policy probabilities.
-    - **Mathematical Operations:**
-      \[
-      (\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
-      \]
-      \[
-      P = \text{softmax}(\text{Policy Logits})
-      \]
-      \[
-      \text{Entropy} = -\sum_{i=1}^{V} P_i \log P_i
-      \]
-      \[
-      \text{Variance} = \text{Var}(P)
-      \]
- 4. **Backpropagation (`backpropagate`):**
-    - **Function:** Updates the `visit_count` and `value_sum` for nodes along the path from the evaluated node back to the root.
-    - **Procedure:**
-      \[
-      \text{For each node in the path:} \\
-      \quad \text{node.visit\_count} \mathrel{+}= 1 \\
-      \quad \text{node.value\_sum} \mathrel{+}= \text{Value Estimate}
-      \]
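This update can be sketched in a few lines (a minimal illustration; the `MCTSNode` here is a stand-in holding only the two statistics the update touches, not the repository's actual class):

```python
class MCTSNode:
    """Stand-in node holding only the statistics that backpropagation touches."""
    def __init__(self, parent=None):
        self.parent = parent
        self.visit_count = 0
        self.value_sum = 0.0

def backpropagate(node, value_estimate):
    # Walk from the evaluated node back up to the root, updating statistics.
    while node is not None:
        node.visit_count += 1
        node.value_sum += value_estimate
        node = node.parent

root = MCTSNode()
child = MCTSNode(parent=root)
backpropagate(child, 0.5)
backpropagate(child, 1.0)
print(root.visit_count, root.value_sum)  # 2 1.5
```

Both the evaluated node and every ancestor accumulate the estimate, which is what lets `Average Value` below reflect all simulations that passed through a node.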
- 5. **Upper Confidence Bound (UCB) Score (`ucb_score`):**
-    - **Function:** Balances exploration of less-visited nodes and exploitation of high-value nodes.
-    - **Mathematical Operation:**
-      \[
-      \text{UCB Score} = \text{Average Value} + \text{Exploration Term} + \text{Entropy Term} + \text{Variance Term}
-      \]
-      Where:
-      \[
-      \text{Average Value} = \frac{\text{value\_sum}}{\text{visit\_count}}
-      \]
-      \[
-      \text{Exploration Term} = \text{exploration\_constant} \times \text{prior} \times \frac{\sqrt{\text{total\_visits}}}{1 + \text{visit\_count}}
-      \]
-      \[
-      \text{Entropy Term} = -0.1 \times \text{entropy}
-      \]
-      \[
-      \text{Variance Term} = 0.05 \times \text{variance}
-      \]
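The formula translates term by term into code; a minimal sketch (argument names are illustrative, not the repository's actual signature):

```python
import math

def ucb_score(value_sum, visit_count, prior, total_visits,
              entropy, variance, exploration_constant=1.0):
    """UCB with entropy and variance modulation, term by term as above."""
    average_value = value_sum / visit_count if visit_count > 0 else 0.0
    exploration = exploration_constant * prior * math.sqrt(total_visits) / (1 + visit_count)
    return average_value + exploration - 0.1 * entropy + 0.05 * variance

# A visited, high-value child versus an unvisited child with a strong prior.
visited = ucb_score(value_sum=3.0, visit_count=2, prior=0.5, total_visits=9,
                    entropy=1.0, variance=0.02)
unvisited = ucb_score(value_sum=0.0, visit_count=0, prior=1.0, total_visits=9,
                      entropy=0.0, variance=0.0)
print(round(visited, 3), round(unvisited, 3))  # 1.901 3.0
```

Note how the unvisited child outranks the visited one purely through the exploration term, which shrinks as its `visit_count` grows.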
- 6. **Best Action Sequence Extraction (`best_action_sequence`):**
-    - **Function:** Extracts the most promising action sequence from the MCTS tree after all iterations.
-    - **Procedure:**
-      - Traverse all possible sequences in the tree.
-      - Score each sequence based on cumulative visit counts, entropy, and variance.
-      - Select the top `beam_size` sequences and return the best one.
- 
- ---
- ### Beam Search with Multi-Token Prediction
- 
- **Purpose:** Efficiently explores multiple possible token sequences to generate coherent and diverse outputs by predicting multiple tokens at each step.
- 
- **Procedure:**
- 
- 1. **Beam Initialization:**
-    - Start with a beam containing the start-of-sequence (BOS) token.
-    \[
-    \text{beam} = \left\{ \left( \text{seq} = [\text{BOS}], \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0 \right) \right\}
-    \]
- 2. **Iterative Expansion:**
-    - For each iteration up to \( \frac{\text{max\_length}}{n\_tokens\_predict} \):
-      - For each sequence in the beam:
-        - Predict the next \( n\_tokens\_predict \) tokens.
-        - Calculate their probabilities.
-        - Select top-k token sequences based on cumulative scores.
- 
- 3. **Beam Pruning:**
-    - After expanding all sequences, retain only the top `beam_size` candidates based on combined scores incorporating log probabilities, entropy, and variance.
- 
- 4. **Termination:**
-    - Continue until the maximum length is reached or all sequences end with the end-of-sequence (EOS) token.
- 
- **Mathematical Operations:**
- \[
- \text{Score} = \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
- \]
- \[
- \text{Entropy} = -\sum P_i \log P_i
- \]
- \[
- \text{Variance} = \text{Var}(P)
- \]
- Where \( P \) is the probability distribution over the vocabulary.
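The expand-and-prune loop can be sketched as follows. This toy version ranks sequences by summed log-probabilities only (the full version also subtracts 0.1 × entropy and adds 0.05 × variance), and `predict_next` is a hypothetical stand-in for the Transformer's next-token distribution:

```python
import heapq
import math

def beam_search(predict_next, bos, eos, beam_size=2, max_length=4):
    """Toy beam search ranking sequences by summed log-probabilities."""
    beam = [(0.0, [bos])]  # (cumulative log-prob score, token sequence)
    for _ in range(max_length):
        candidates = []
        for score, seq in beam:
            if seq[-1] == eos:  # finished sequences are carried over unchanged
                candidates.append((score, seq))
                continue
            for tok, p in predict_next(seq).items():
                candidates.append((score + math.log(p), seq + [tok]))
        # Prune: keep only the top beam_size candidates.
        beam = heapq.nlargest(beam_size, candidates, key=lambda c: c[0])
    return beam[0][1]

# Stand-in distribution: the model strongly prefers "a" over stopping.
def predict_next(seq):
    return {"a": 0.9, "<eos>": 0.1}

best = beam_search(predict_next, "<bos>", "<eos>")
print(best)  # ['<bos>', 'a', 'a', 'a', 'a']
```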
- 
- ### Upper Confidence Bound (UCB) in MCTS
- 
- **Purpose:** Balances exploration of less-visited nodes and exploitation of high-value nodes during tree traversal.
- 
- **Mathematical Formulation:**
- \[
- \text{UCB Score} = \text{Average Value} + \text{Exploration Term} + \text{Entropy Term} + \text{Variance Term}
- \]
- Where:
- \[
- \text{Average Value} = \frac{\text{value\_sum}}{\text{visit\_count}}
- \]
- \[
- \text{Exploration Term} = \text{exploration\_constant} \times \text{prior} \times \frac{\sqrt{\text{total\_visits}}}{1 + \text{visit\_count}}
- \]
- \[
- \text{Entropy Term} = -0.1 \times \text{entropy}
- \]
- \[
- \text{Variance Term} = 0.05 \times \text{variance}
- \]
- 
- **Explanation:**
- - **Average Value:** Encourages exploitation of nodes with high expected rewards.
- - **Exploration Term:** Encourages visiting nodes with higher uncertainty (low visit counts).
- - **Entropy and Variance Terms:** Modulate preferences based on the policy distribution's entropy and variance, promoting diverse and balanced exploration.
- 
- ### Entropy and Variance Calculations
- 
- **Purpose:** Measure the uncertainty and diversity of the policy distribution, influencing the exploration-exploitation balance.
- 
- **Mathematical Formulation:**
- \[
- \text{Entropy} = -\sum_{i=1}^{V} P_i \log P_i
- \]
- \[
- \text{Variance} = \frac{1}{V} \sum_{i=1}^{V} (P_i - \mu)^2
- \]
- Where:
- - \( P_i \): Probability of action \( i \).
- - \( \mu = \frac{1}{V} \sum_{i=1}^{V} P_i \): Mean probability.
- 
- **Explanation:**
- - **Entropy:** Quantifies the unpredictability of the distribution. High entropy indicates a more uniform distribution, promoting exploration.
- - **Variance:** Measures the spread of the probabilities. High variance can indicate diverse preferences among actions.
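Both quantities follow directly from the definitions; a minimal standard-library sketch:

```python
import math

def policy_entropy_variance(probs):
    """Entropy and variance of a policy distribution P over V actions."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    mean = sum(probs) / len(probs)
    variance = sum((p - mean) ** 2 for p in probs) / len(probs)
    return entropy, variance

uniform = [0.25, 0.25, 0.25, 0.25]  # maximal uncertainty
peaked = [0.97, 0.01, 0.01, 0.01]   # confident policy
print(policy_entropy_variance(uniform))
print(policy_entropy_variance(peaked))
```

The uniform distribution attains the maximum entropy \(\log V\) with zero variance, while the peaked one has low entropy and high variance, which is exactly the contrast the UCB terms exploit.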
- 
- ---
- 
- ## Inference Workflow
- 
- The inference process in the LightBulb model can operate in two distinct modes:
- 
- 1. **Without World Model:** Utilizes the Transformer with beam search to generate text directly.
- 2. **With World Model and Tree of Thought:** Employs the World Model components alongside MCTS and ToT for generating a sequence of thoughts/actions.
- 
- ### Inference Modes
- 
- #### Language Model
- 
- **Procedure:**
- 
- 1. **Input Processing:**
-    - Tokenize the input query.
-    - Encode tokens into embeddings via the Transformer.
- 
- 2. **Beam Search Generation:**
-    - Use the Transformer's `generate_with_beam_search` method.
-    - Predict multiple tokens at each step (`n_tokens_predict`).
-    - Maintain a beam of top `beam_size` sequences based on cumulative scores.
- 
- 3. **Output Decoding:**
-    - Select the best sequence based on scores.
-    - Decode token IDs back into human-readable text.
- 
- **Mathematical Operations:**
- 
- - **Beam Search Score:**
-   \[
-   \text{Score} = \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
-   \]
- 
- - **Entropy and Variance:**
-   \[
-   \text{Entropy} = -\sum P_i \log P_i
-   \]
-   \[
-   \text{Variance} = \text{Var}(P)
-   \]
- 
- #### World Model and Tree of Thought
- 
- **Procedure:**
- 
- 1. **Input Processing:**
-    - Tokenize the input query.
-    - Encode tokens into embeddings via the Transformer.
-    - Transform the output embeddings into a state representation using the `RepresentationNetwork`.
- 
- 2. **MCTS Initialization:**
-    - Create a root `MCTSNode` with the initial state and the root `ThoughtNode` from ToT.
-    - Evaluate the root node to obtain policy logits and value estimates.
-    - Backpropagate the evaluation to update visit counts and value sums.
- 
- 3. **MCTS Iterations with Beam Search:**
-    - For a predefined number of iterations (`num_iterations`):
-      - **Beam Expansion:**
-        - For each node in the current beam:
-          - If it's a leaf node, evaluate and backpropagate its value.
-          - Select top `beam_size` actions based on UCB scores.
-          - For each selected action:
-            - Predict a sequence of `n_tokens_predict` actions.
-            - Accumulate scores, entropy, and variance.
-            - Add the candidate sequence to `all_candidates`.
-      - **Beam Pruning:**
-        - Sort all candidates based on a combined score:
-          \[
-          \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
-          \]
-        - Retain the top `beam_size` candidates for the next iteration.
- 
- 4. **Output Generation:**
-    - After completing iterations, select the best action sequence from the final beam.
-    - Return the sequence of actions (thoughts) as the output.
- 
- **Mathematical Operations:**
- 
- - **UCB Score:**
-   \[
-   \text{UCB Score} = \frac{\text{value\_sum}}{\text{visit\_count}} + \text{exploration\_constant} \times \text{prior} \times \frac{\sqrt{\text{total\_visits}}}{1 + \text{visit\_count}} - 0.1 \times \text{entropy} + 0.05 \times \text{variance}
-   \]
- 
- - **Beam Score Adjustment:**
-   \[
-   \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
-   \]
- 
  ---
 
- ## Inference Execution
- 
- ### Mode: Without World Model
- 
- **Step 1: Input Tokenization and Encoding**
- 
- 1. **Tokenization:**
-    \[
-    \text{input\_ids} = \text{tokenizer.encode(query, return\_tensors='pt')}
-    \]
-    - Shape: \((\text{batch\_size}=1, \text{seq\_len})\)
- 
- 2. **Encoding via Transformer:**
-    \[
-    \text{transformer\_output} = \text{model\_transformer}(input\_ids, input\_ids)
-    \]
-    - Shape: \((\text{batch\_size}=1, \text{seq\_len}, \text{d\_model})\)
- 
- **Step 2: Beam Search Generation**
- 
- 1. **Beam Search Initialization:**
-    \[
-    \text{beam} = \left\{ \left( \text{seq} = [\text{BOS}], \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0 \right) \right\}
-    \]
- 
- 2. **Iterative Expansion:**
-    - For each iteration up to \(\frac{\text{max\_length}}{n\_tokens\_predict}\):
-      - For each sequence in the beam:
-        - If the last token is EOS, retain the sequence.
-        - Else, predict the next \( n\_tokens\_predict \) tokens.
-        - Calculate probabilities, entropy, and variance.
-        - Select top-k tokens for each position based on beam size.
-        - Generate all possible continuations.
-      - **Score Calculation:**
-        \[
-        \text{new\_score} = \text{score} + \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
-        \]
-      - **Entropy and Variance Accumulation:**
-        \[
-        \text{new\_entropy} = \text{cum\_entropy} + \sum_{t=1}^{n} \text{Entropy}_t
-        \]
-        \[
-        \text{new\_variance} = \text{cum\_variance} + \sum_{t=1}^{n} \text{Variance}_t
-        \]
-      - **Candidate Aggregation:**
-        - Append new sequences to `all_candidates`.
- 
- 3. **Beam Pruning:**
-    - Sort `all_candidates` based on combined scores.
-    - Retain the top `beam_size` candidates for the next iteration.
- 
- **Step 3: Output Decoding**
- 
- 1. **Select Best Sequence:**
-    \[
-    \text{best\_sequence} = \text{beam}[0][0]
-    \]
- 
- 2. **Decode Tokens to Text:**
-    \[
-    \text{generated\_text} = \text{tokenizer.decode(best\_sequence, skip\_special\_tokens=True)}
-    \]
- 
- ---
- 
- ### Mode: With World Model and Tree of Thought
- 
- **Step 1: Input Tokenization and Encoding**
- 
- 1. **Tokenization:**
-    \[
-    \text{input\_ids} = \text{tokenizer.encode(query, return\_tensors='pt')}
-    \]
-    - Shape: \((\text{batch\_size}=1, \text{seq\_len})\)
- 
- 2. **Encoding via Transformer:**
-    \[
-    \text{transformer\_output} = \text{model\_transformer}(input\_ids, input\_ids)
-    \]
-    - Shape: \((\text{batch\_size}=1, \text{seq\_len}, \text{d\_model})\)
- 
- 3. **State Representation:**
-    \[
-    \text{initial\_representation} = \text{RepresentationNetwork}(\text{transformer\_output})[:, -1, :].unsqueeze(1)
-    \]
-    - Shape: \((\text{batch\_size}=1, 1, \text{state\_dim})\)
- 
- 4. **State Initialization:**
-    \[
-    \text{initial\_state} = \text{State}(\text{representation}=\text{initial\_representation}, \text{dynamics\_network}=dynamics\_network, \text{action\_encoder}=action\_encoder, \text{thought\_node}=root\_thought\_node)
-    \]
- 
- **Step 2: MCTS Initialization and Root Node Evaluation**
- 
- 1. **MCTS Instance Creation:**
-    \[
-    \text{mcts} = \text{MCTS}(\text{prediction\_network}, \text{dynamics\_network}, \text{action\_encoder}, \text{num\_iterations}=mcts\_iterations, \text{exploration\_constant}=exploration\_constant)
-    \]
- 
- 2. **Root Node Creation:**
-    \[
-    \text{root\_node} = \text{MCTSNode}(\text{state}=\text{initial\_state}, \text{thought\_node}=root\_thought\_node)
-    \]
- 
- 3. **Root Node Evaluation:**
-    \[
-    \text{value\_estimate} = \text{mcts.evaluate}(\text{root\_node})
-    \]
- 
- 4. **Backpropagation:**
-    \[
-    \text{mcts.backpropagate}(\text{root\_node}, \text{value\_estimate})
-    \]
- 
- **Step 3: MCTS Iterations with Beam Search**
- 
- 1. **Beam Initialization:**
-    \[
-    \text{beam} = \left\{ \left( \text{root\_node}, \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0, \text{action\_sequence} = [] \right) \right\}
-    \]
- 
- 2. **Iterative Expansion:**
-    - For each iteration up to `num_iterations`:
-      - **Candidate Collection:**
-        - For each node in the current beam:
-          - **Leaf Evaluation:**
-            - If `node.is_leaf()`:
-              \[
-              \text{value\_estimate} = \text{mcts.evaluate}(\text{node})
-              \]
-              \[
-              \text{mcts.backpropagate}(\text{node}, \text{value\_estimate})
-              \]
-          - **Child Selection:**
-            - If `node.children` is not empty:
-              - Calculate `total_visits`:
-                \[
-                \text{total\_visits} = \sum_{\text{child} \in \text{node.children}} \text{child.visit\_count}
-                \]
-              - Select top `beam_size` actions based on UCB scores:
-                \[
-                \text{sorted\_children} = \text{sorted}(\text{node.children.items()}, \text{key}=\lambda \text{item}: \text{item}[1].ucb\_score(\text{total\_visits}, \text{exploration\_constant}), \text{reverse}=True)[:\text{beam\_size}]
-                \]
-          - **Action Sequence Prediction:**
-            - For each selected action:
-              - Initialize `current_node`, `current_sequence`, `current_score`, `current_entropy`, `current_variance`.
-              - **Multi-Token Prediction:**
-                - For each step in `n_tokens_predict`:
-                  - If `current_node.is_leaf()`:
-                    \[
-                    \text{value\_estimate} = \text{mcts.evaluate}(\text{current\_node})
-                    \]
-                    \[
-                    \text{mcts.backpropagate}(\text{current\_node}, \text{value\_estimate})
-                    \]
-                  - If `current_node.children` is empty, break.
-                  - Calculate `total_visits` for the new node.
-                  - Select the action with the highest UCB score:
-                    \[
-                    (\text{next\_action}, \text{next\_node}) = \text{max}(\text{current\_node.children.items()}, \text{key}=\lambda \text{item}: \text{item}[1].ucb\_score(\text{total\_visits}, \text{exploration\_constant}))
-                    \]
-                  - **Score Update:**
-                    \[
-                    \text{current\_score} \mathrel{+}= \frac{\text{next\_node.value\_sum}}{\text{next\_node.visit\_count}} \quad \text{if} \quad \text{next\_node.visit\_count} > 0 \quad \text{else} \quad 0
-                    \]
-                  - **Entropy and Variance Update:**
-                    \[
-                    \text{current\_entropy} \mathrel{+}= \text{next\_node.entropy}
-                    \]
-                    \[
-                    \text{current\_variance} \mathrel{+}= \text{next\_node.variance}
-                    \]
-                  - Append `next_action` to `current_sequence`.
-                  - Update `current_node` to `next_node`.
-              - **Candidate Aggregation:**
-                \[
-                \text{all\_candidates.append}((\text{current\_node}, \text{current\_score}, \text{current\_entropy}, \text{current\_variance}, \text{current\_sequence}))
-                \]
- 
-      - **Beam Pruning:**
-        \[
-        \text{beam} = \text{sorted}(\text{all\_candidates}, \text{key}=\lambda x: x[1] - 0.1 \times x[2] + 0.05 \times x[3], \text{reverse}=True)[:\text{beam\_size}]
-        \]
- 
- 3. **Termination:**
-    - Stop early if no candidates remain or all beams have reached terminal nodes.
- 
- 4. **Result Extraction:**
-    \[
-    \text{best\_sequence} = \text{beam}[0][4]
-    \]
-    - Return `best_sequence` as the generated sequence of actions (thoughts).
- 
- ---
- 
- ## Integration of Components
- 
- The inference process seamlessly integrates multiple components to facilitate advanced reasoning:
- 
- 1. **Transformer for Sequence Encoding and Generation:**
-    - Processes input sequences and generates embeddings.
-    - Facilitates beam search for direct text generation.
- 
- 2. **World Model for State Representation and Dynamics:**
-    - `RepresentationNetwork` encodes transformer outputs into state representations.
-    - `DynamicsNetwork` predicts state transitions based on actions.
-    - `PredictionNetwork` provides policy logits and value estimates.
-    - `ActionEncoder` encodes actions into embeddings for state transitions.
- 
- 3. **Tree of Thought for Structured Reasoning:**
-    - Organizes possible thoughts/actions hierarchically.
-    - Enables systematic exploration of reasoning paths.
- 
- 4. **Monte Carlo Tree Search for Strategic Exploration:**
-    - Utilizes ToT to explore potential reasoning paths.
-    - Balances exploration and exploitation using UCB scores.
-    - Incorporates beam search with multi-token prediction to handle complex, multi-step actions.
- 
- **Workflow Integration:**
- 
- - **Input Processing:** Input query is tokenized and encoded via the Transformer.
- - **State Representation:** Encoded inputs are transformed into initial state representations.
- - **MCTS Integration:** MCTS uses the World Model components to explore and evaluate possible thought sequences within the Tree of Thought.
- - **Beam Search:** Multi-token beam search within MCTS ensures diverse and coherent exploration of actions.
- - **Output Generation:** The best sequence of thoughts/actions is extracted and returned as the inference result.
- 
- ---
- 
- I am utilising Trees of Thought to structure sets of policies and sequences of actions. These tree structures give the World Model a general thought structure and pattern, similar to how humans create thought patterns for solving certain problems (e.g. understand, describe, analyse).
- 
- Here are some example Trees of Thought:
- 
- ```mermaid
- graph TD
-     A[Problem-Solving Process] --> B[Problem Identification]
-     A --> C[Problem Analysis]
-     A --> D[Solution Generation]
-     A --> E[Implementation]
-     A --> F[Evaluation and Adjustment]
-     B --> B1[Define the Problem]
-     B --> B2[Identify Stakeholders]
-     B --> B3[Determine Constraints]
-     B --> B4[Recognize Problem Type]
-     B --> B5[Historical Context]
-     C --> C1[Root Cause Analysis]
-     C --> C2[System Mapping]
-     C --> C3[Data Collection]
-     C --> C4[Impact Assessment]
-     C --> C5[Theoretical Framework]
-     D --> D1[Creative Problem Solving]
-     D --> D2[Analytical Approach]
-     D --> D3[Mathematical Computation]
-     D --> D4[Decision Making]
-     E --> E1[Action Planning]
-     E --> E2[Resource Allocation]
-     E --> E3[Change Management]
-     F --> F1[Verification]
-     F --> F2[Performance Metrics]
-     F --> F3[Feedback Loops]
-     F --> F4[Continuous Improvement]
-     C3 --> C3a[Quantitative Data]
-     C3 --> C3b[Qualitative Data]
-     C3 --> C3c[Data Validation]
-     D1 --> D1a[Divergent Thinking]
-     D1 --> D1b[Convergent Thinking]
-     D1 --> D1c[Lateral Thinking]
-     D2 --> D2a[Logical Reasoning]
-     D2 --> D2b[Critical Analysis]
-     D2 --> D2c[Systems Thinking]
-     D3 --> D3a[Basic Operations]
-     D3 --> D3b[Advanced Operations]
-     D3 --> D3c[Computational Methods]
-     D4 --> D4a[Decision Trees]
-     D4 --> D4b[Multi-Criteria Analysis]
-     D4 --> D4c[Probabilistic Reasoning]
-     G[Cross-Cutting Considerations] --> G1[Ethical Framework]
-     G --> G2[Stakeholder Management]
-     G --> G3[Interdisciplinary Connections]
-     G --> G4[Technological Integration]
-     G --> G5[Emotional Intelligence]
-     G --> G6[Collaborative Problem Solving]
-     G1 --> G1a[Value-based Decision Making]
-     G1 --> G1b[Long-term Consequences]
-     G2 --> G2a[Direct Stakeholders]
-     G2 --> G2b[Indirect Stakeholders]
-     G2 --> G2c[Conflicting Interests]
-     G3 --> G3a[Related Fields]
-     G3 --> G3b[Cross-disciplinary Impact]
-     G4 --> G4a[AI-assisted Problem Solving]
-     G4 --> G4b[Data-driven Insights]
-     G4 --> G4c[Digital Collaboration Tools]
-     G5 --> G5a[Self-Awareness]
-     G5 --> G5b[Empathy]
-     G5 --> G5c[Stress Management]
-     G6 --> G6a[Team Dynamics]
-     G6 --> G6b[Communication Strategies]
-     G6 --> G6c[Conflict Resolution]
-     H[Computational Considerations] --> H1[CPU Operations]
-     H --> H2[GPU Parallelization]
-     H --> H3[Floating-Point Precision]
-     I[Order of Operations] --> I1[Parentheses]
-     I --> I2[Exponents]
-     I --> I3[Multiplication and Division]
-     I --> I4[Addition and Subtraction]
-     J[Critical Thinking] --> J1[Assumptions Questioning]
-     J --> J2[Bias Recognition]
-     K[Future Perspective] --> K1[Short-term Projections]
-     K --> K2[Long-term Scenarios]
-     K --> K3[Potential Impacts]
-     L[Learning and Adaptation] --> L1[Reflective Practice]
-     L --> L2[Knowledge Transfer]
-     L --> L3[Adaptive Problem Solving]
- ```
- 
- ```mermaid
- graph TD
-     A[Meta-Cognitive Strategies] --> B[Creative Problem Solving]
-     A --> C[Systems Thinking]
-     A --> D[Decision Making]
-     A --> E[Emotional Intelligence]
-     A --> F[Collaborative Problem Solving]
-     B --> B1[Divergent Thinking]
-     B --> B2[Convergent Thinking]
-     B --> B3[Lateral Thinking]
-     C --> C1[Holistic Perspective]
-     C --> C2[Feedback Loops]
-     C --> C3[Emergent Properties]
-     D --> D1[Decision Trees]
-     D --> D2[Multi-Criteria Decision Analysis]
-     D --> D3[Probabilistic Reasoning]
-     E --> E1[Self-Awareness]
-     E --> E2[Empathy]
-     E --> E3[Stress Management]
-     F --> F1[Team Dynamics]
-     F --> F2[Communication Strategies]
-     F --> F3[Conflict Resolution]
-     G[Learning and Adaptation]
-     A --> G
-     G --> G1[Reflective Practice]
-     G --> G2[Knowledge Transfer]
-     G --> G3[Adaptive Problem Solving]
-     H[Ethical Framework]
-     A --> H
-     H --> H1[Value-based Decision Making]
-     H --> H2[Stakeholder Analysis]
-     H --> H3[Long-term Consequences]
-     I[Technological Integration]
-     A --> I
-     I --> I1[AI-assisted Problem Solving]
-     I --> I2[Data-driven Insights]
-     I --> I3[Digital Collaboration Tools]
- ```
- 
- ## Requirements
- 
- This code requires:
- - Python 3.7+
- - `torch>=1.7.1`
- - `transformers`
- - `datasets`
- - `argparse`
- 
- ## Citation
- 
- If you use this model in your research, please cite the author.

  ---
  license: apache-2.0

  ---

+ # Model Card for LightBulb

+ ## Overview

+ **LightBulb** is an advanced framework designed to train and utilize language models and autonomous web search agents. It integrates hierarchical reinforcement learning, Monte Carlo Tree Search (MCTS), and Tree of Thought (ToT) architectures to enable sophisticated reasoning and decision-making capabilities. The framework supports both training and inference for language models, web search agents, and comprehensive world models.

+ ## Installation

+ To install the necessary dependencies, run:

  ```bash
+ pip install huggingface_hub torch transformers datasets argparse
  ```

+ ## Getting Started

+ ### Download the Repository

+ Use `huggingface_hub` to download the repository:

+ ```python
+ from huggingface_hub import snapshot_download
+ 
+ # Download the repository
+ repo_path = snapshot_download("RobbiePasquale/lightbulb")
+ 
+ print(f"Repository downloaded to: {repo_path}")
+ ```

+ ## Main Features

+ LightBulb provides seven primary functionalities, each accessible via the `main_menu.py` script using command-line arguments.
+ ### 1. Train a Web Search Agent

+ **Description:**
+ Trains an autonomous web search agent that navigates the web, gathers relevant content, and learns to summarize and generate responses based on user queries.

+ **Usage:**
  ```bash
  python main_menu.py --task train_agent
  ```

+ **Key Components:**
+ - **Hierarchical Reinforcement Learning (HRL):** Manages high-level (Manager) and low-level (Worker) decision-making.
+ - **Monte Carlo Tree Search (MCTS):** Guides the agent through complex decision trees.
+ - **RAGSummarizer:** Summarizes retrieved web content.
+ - **Knowledge Base:** Stores and retrieves information to inform future queries.

+ ### 2. Use a Web Search Agent (Inference)

+ **Description:**
+ Utilizes the trained web search agent to process queries, perform web searches, and generate summarized responses.

+ **Usage:**
  ```bash
  python main_menu.py --task test_agent
  ```

+ **Options:**
+ - **Interactive Mode:**
+   ```bash
+   python main_menu.py --task test_agent
+   ```
+ - **Single Query Mode:**
+   ```bash
+   python main_menu.py --task test_agent --query "Your query here"
+   ```

+ ### 3. Train a Language Model

+ **Description:**
+ Trains a Language Model (LLM) and World Model using datasets from Hugging Face, enabling the model to handle complex reasoning and long sequences.

+ **Usage:**
+ ```bash
+ python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256
  ```

+ **Key Arguments:**
+ - `--model_name`: Pretrained model (e.g., `gpt2`, `bert`).
+ - `--dataset_name`: Dataset from Hugging Face (e.g., `wikitext`).
+ - `--num_epochs`: Number of training epochs.
+ - `--batch_size`: Number of samples per batch.
+ - `--max_length`: Maximum sequence length.

+ ### 4. Inference Using Language Model with Multi-Token Prediction, Beam Search, and MCTS

+ **Description:**
+ Generates responses using the trained language model, leveraging multi-token prediction, beam search, and MCTS for enhanced coherence and strategic reasoning.

+ **Usage:**
+ ```bash
+ python main_menu.py --task inference_llm --query "Your query here"
  ```

+ **Process:**
+ 1. **Multi-Token Prediction:** Predicts multiple tokens at each step to improve generation speed.
+ 2. **Beam Search:** Maintains multiple candidate sequences to ensure diverse and high-quality outputs.
+ 3. **MCTS Integration:** Uses MCTS to evaluate and select the most promising token sequences based on policy and value estimates.

+ ### 5. Train a Language World Model

+ **Description:**
+ Develops a comprehensive World Model that encapsulates state representations, dynamics, and prediction networks to simulate and predict state transitions within the Tree of Thought framework.

+ **Usage:**
+ ```bash
+ python main_menu.py --task train_world_model --additional_args
  ```

+ **Key Components:**
+ - **Representation Network:** Encodes Transformer outputs into state representations.
+ - **Dynamics Network:** Predicts next states based on current states and actions.
+ - **Prediction Network:** Generates policy logits and value estimates.
+ - **Action Encoder:** Encodes actions into embeddings for state transitions.

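How these four components might fit together can be sketched as follows. The module names match the list above, but the layer shapes and internals here are illustrative assumptions, not the repository's actual architecture:

```python
import torch
import torch.nn as nn

STATE_DIM, ACTION_DIM, VOCAB = 32, 16, 100  # illustrative dimensions

class RepresentationNetwork(nn.Module):
    """Encodes transformer outputs into a state representation."""
    def __init__(self, d_model=64, state_dim=STATE_DIM):
        super().__init__()
        self.proj = nn.Linear(d_model, state_dim)
    def forward(self, transformer_output):
        return torch.tanh(self.proj(transformer_output))

class DynamicsNetwork(nn.Module):
    """Predicts the next state from the current state and an action embedding."""
    def __init__(self, state_dim=STATE_DIM, action_dim=ACTION_DIM):
        super().__init__()
        self.net = nn.Linear(state_dim + action_dim, state_dim)
    def forward(self, state, action_emb):
        return torch.tanh(self.net(torch.cat([state, action_emb], dim=-1)))

class PredictionNetwork(nn.Module):
    """Produces policy logits and a scalar value estimate for a state."""
    def __init__(self, state_dim=STATE_DIM, vocab=VOCAB):
        super().__init__()
        self.policy = nn.Linear(state_dim, vocab)
        self.value = nn.Linear(state_dim, 1)
    def forward(self, state):
        return self.policy(state), self.value(state)

action_encoder = nn.Embedding(VOCAB, ACTION_DIM)  # encodes actions for transitions
repr_net, dyn_net, pred_net = RepresentationNetwork(), DynamicsNetwork(), PredictionNetwork()

state = repr_net(torch.randn(1, 64))                          # encode
next_state = dyn_net(state, action_encoder(torch.tensor([3])))  # transition
logits, value = pred_net(next_state)                          # evaluate
print(logits.shape, value.shape)
```

The separation mirrors MuZero-style world models: the dynamics network lets MCTS simulate transitions without re-running the Transformer at every node.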
+ ### 6. Inference with Language World Model

+ **Description:**
+ Utilizes the trained World Model to perform advanced reasoning and generate responses based on structured thought processes and state simulations.

+ **Usage:**
+ ```bash
+ python main_menu.py --task inference_world_model --query "Your query here"
  ```

+ **Features:**
+ - **Tree of Thought (ToT):** Structures reasoning paths hierarchically.
+ - **Beam Search with MCTS:** Enhances decision-making by balancing exploration and exploitation.
+ - **Integration with Knowledge Base:** Leverages stored information for informed responses.

+ ### 7. Inference with World Model, Tree of Thought, and Multi-Token Beam Search

+ **Description:**
+ Executes inference using the World Model integrated with ToT and multi-token beam search for highly coherent and contextually rich outputs.

+ **Usage:**
+ ```bash
+ python main_menu.py --task advanced_inference --query "Your complex query here"
  ```

+ **Process:**
+ 1. **State Initialization:** Converts input queries into state representations.
+ 2. **MCTS with Beam Search:** Explores multiple reasoning paths simultaneously.
+ 3. **Thought Sequence Generation:** Produces a sequence of interconnected thoughts/actions.
+ 4. **Final Response Generation:** Synthesizes the best thought path into a coherent response.
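The ranking that decides which thought paths survive each MCTS iteration can be sketched like this. The candidate dictionaries and thought labels are hypothetical; the weights mirror the combined score used throughout this card:

```python
def prune_beam(candidates, beam_size=3):
    """Keep the top candidates by combined score = score - 0.1*entropy + 0.05*variance."""
    combined = lambda c: c["score"] - 0.1 * c["entropy"] + 0.05 * c["variance"]
    return sorted(candidates, key=combined, reverse=True)[:beam_size]

candidates = [
    {"path": ["define", "analyse"], "score": 1.2, "entropy": 0.9, "variance": 0.02},
    {"path": ["define", "compute"], "score": 1.1, "entropy": 0.2, "variance": 0.01},
    {"path": ["brainstorm"],        "score": 0.4, "entropy": 1.5, "variance": 0.10},
]
best = prune_beam(candidates, beam_size=2)
print([c["path"] for c in best])  # [['define', 'analyse'], ['define', 'compute']]
```

The entropy penalty nudges the search away from paths whose policy was highly uncertain, while the small variance bonus rewards policies with clearly differentiated preferences.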
 
+ ## General Arguments

+ | Argument           | Required | Description                                                                                    | Default             |
+ |--------------------|----------|------------------------------------------------------------------------------------------------|---------------------|
+ | `--task`           | Yes      | Specifies the task to run (`train_llm_world`, `train_agent`, `test_agent`, etc.).              | None                |
+ | `--model_name`     | No       | Pretrained model name for LLM (`gpt2`, `bert`, etc.) or a custom model path.                   | `gpt2`              |
+ | `--dataset_name`   | No       | Name of the dataset from Hugging Face for training the LLM and World Model (e.g., `wikitext`). | `wikitext`          |
+ | `--dataset_config` | No       | Configuration name for the dataset.                                                            | `wikitext-2-raw-v1` |
+ | `--batch_size`     | No       | Number of samples per batch during training.                                                   | `4`                 |
+ | `--num_epochs`     | No       | Number of training epochs.                                                                     | `3`                 |
+ | `--max_length`     | No       | Maximum sequence length for training/inference.                                                | `128`               |
+ | `--mode`           | No       | Mode for LLM and World Model (`train`, `inference`).                                           | `train`             |
+ | `--query`          | No       | Query input for `test_agent` when running a single query.                                      | `''` (empty)        |
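The table corresponds to an argument parser along these lines (a sketch consistent with the listed defaults; the actual `main_menu.py` may differ, and the extra task names are taken from the feature list above):

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description="LightBulb main menu")
    parser.add_argument("--task", required=True,
                        choices=["train_llm_world", "train_agent", "test_agent",
                                 "inference_llm", "train_world_model",
                                 "inference_world_model", "advanced_inference"])
    parser.add_argument("--model_name", default="gpt2")
    parser.add_argument("--dataset_name", default="wikitext")
    parser.add_argument("--dataset_config", default="wikitext-2-raw-v1")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mode", default="train", choices=["train", "inference"])
    parser.add_argument("--query", default="")
    return parser

# Parse a sample command line without touching sys.argv.
args = build_parser().parse_args(["--task", "test_agent", "--query", "solar power"])
print(args.task, args.batch_size)  # test_agent 4
```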
 
+ ## Requirements

+ - **Python:** 3.7+
+ - **Libraries:**
+   - `torch>=1.7.1`
+   - `transformers`
+   - `datasets`
+   - `argparse`
+   - `huggingface_hub`

+ ## Usage Examples

+ ### Training the Language Model and World Model

+ ```bash
+ python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256
  ```

+ ### Training the Web Search Agent

+ ```bash
+ python main_menu.py --task train_agent
  ```

+ ### Testing the Web Search Agent in Interactive Mode

+ ```bash
+ python main_menu.py --task test_agent
  ```

+ ### Testing the Web Search Agent with a Single Query

+ ```bash
+ python main_menu.py --task test_agent --query "What are the impacts of renewable energy on global sustainability?"
  ```

+ ### Advanced Inference with World Model and Tree of Thought

+ ```bash
+ python main_menu.py --task advanced_inference --query "Analyze the economic effects of artificial intelligence in the next decade."
  ```

+ ## Citation

+ If you use LightBulb in your research, please cite the author:

  ```
+ @misc{RobbiePasquale_lightbulb,
+   author       = {Robbie Pasquale},
+   title        = {LightBulb: An Autonomous Web Search and Language Model Framework},
+   year         = {2024},
+   publisher    = {Huggingface},
+   howpublished = {\url{https://huggingface.co/RobbiePasquale/lightbulb}},
+ }
  ```

+ ## License

+ This project is licensed under the Apache 2.0 License.

  ---

+ For more detailed information on each component and advanced configurations, please refer to the [documentation](https://huggingface.co/RobbiePasquale/lightbulb).