---
license: apache-2.0
---

## Main Menu
The `main_menu.py` script is the primary entry point for choosing and executing one of three tasks:

1. **Training the LLM and World Model**: `train_llm_world`
2. **Training the Search Agent**: `train_agent`
3. **Testing the Tree of Thought Search Agent**: `test_agent`

Each task has its own functionality and configuration. The script uses command-line arguments to select the task and set additional options, letting users tailor execution to their needs.

### Running the Main Menu

To run the main menu, use the following command in the terminal:

```bash
python main_menu.py --task <task_name> [additional arguments]
```

Replace `<task_name>` with one of the following:

- `train_llm_world` - Train the LLM (Language Model) and World Model.
- `train_agent` - Train the Search Agent with an interactive Twisted-based process.
- `test_agent` - Test the Tree of Thought Search Agent, either interactively or with a single query.

### General Arguments

The script supports a set of command-line arguments to customize each task. Here is an overview of all possible arguments:

| Argument | Required | Description | Default |
|------------------------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|
| `--task` | Yes | Specifies the task to run. Choose from `train_llm_world`, `train_agent`, or `test_agent`. | None |
| `--model_name` | No | Pretrained model name for the LLM (`gpt2`, `bert`, etc.) or a path to a custom model. | `gpt2` |
| `--dataset_name` | No | Name of the dataset from Hugging Face Datasets for training the LLM and World Model (e.g., `wikitext`). | `wikitext` |
| `--dataset_config` | No | Dataset configuration name for selecting a specific version or variation of the dataset. | `wikitext-2-raw-v1` |
| `--batch_size` | No | Number of samples processed in a single forward/backward pass. A larger batch size can speed up training but requires more memory. | `4` |
| `--num_epochs` | No | Number of passes over the training dataset. More epochs generally improve learning but can lead to overfitting. | `3` |
| `--max_length` | No | Maximum sequence length for training/inference. Sequences are truncated or padded to this length. | `128` |
| `--mode` | No | Mode for the LLM and World Model: `train` for training, `inference` for generating responses. | `train` |
| `--query` | No | Query input for `test_agent` when running a single query instead of an interactive session. | `''` (empty) |

### 1. Training the LLM and World Model (`train_llm_world`)

#### Example Usage
```bash
python main_menu.py --task train_llm_world
```
- **`--model_name`**: Name of the pretrained model to use for language model training. Specify a model name (like `gpt2`, `bert`, etc.) or a path to a custom model. This choice determines the model architecture and tokenization style.

- **`--dataset_name`**: Dataset from Hugging Face's Datasets library to train on. Options include `wikitext`, `imdb`, `squad`, etc. You can also use a custom dataset by specifying its path.

- **`--dataset_config`**: Configuration of the dataset, i.e., a specific version or variation. For example, `wikitext` includes configurations such as `wikitext-2-raw-v1`. The configuration affects the format and content of the data.

- **`--batch_size`**: Number of samples per batch. A larger batch size requires more memory but can speed up training; reduce it if memory is limited.

- **`--num_epochs`**: Number of complete passes through the training dataset. More epochs can improve the model's ability to learn but may lead to overfitting if set too high.

- **`--max_length`**: Maximum length of the input sequence. Longer sequences are truncated and shorter ones are padded. This affects both training and inference.

- **`--mode`**: Operation to perform. Choose `train` to train the model, or `inference` to generate text from input.
### 2. Training the Search Agent (`train_agent`)

This section gives a detailed breakdown of the search agent, covering training, inference, and the functionality of each component, including how the agent saves LLM training data, its modular structure, and the role of each module.

---

## Overview of the AutonomousWebAgent

The `AutonomousWebAgent` is a multi-component search and retrieval agent designed to navigate the web, gather relevant content, and perform summarization and generation based on user queries. It integrates reinforcement learning (RL), Monte Carlo Tree Search (MCTS), a Retrieval-Augmented Generation (RAG) Summarizer, and a Hierarchical Reinforcement Learning (HRL) architecture to select, execute, and optimize its actions based on past experience.

### Key Components

1. **Prioritized Experience Replay**:
   - The agent uses a `PrioritizedReplayMemory` backed by a `SumTree` to prioritize and store experiences (transitions between states).
   - The `SumTree` maintains a binary tree in which each parent node's value is the sum of its children, so experiences can be stored, updated, and retrieved by priority efficiently.
   - These experiences train both the high-level (manager) and low-level (worker) components through prioritized sampling during replay, letting the models focus on the most significant transitions.
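The sum-tree structure described above can be sketched in a few lines of Python. This is a minimal illustration of the idea, not the repository's actual `SumTree` API; the method names and fixed-capacity layout are assumptions.

```python
class SumTree:
    """Binary tree where each parent stores the sum of its children.

    Leaves hold experience priorities; internal nodes hold partial sums,
    so priority-proportional sampling takes O(log n).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)  # internal nodes + leaves
        self.data = [None] * capacity
        self.write = 0  # next leaf slot to overwrite

    def add(self, priority, experience):
        idx = self.write + self.capacity - 1  # leaf index in the flat tree
        self.data[self.write] = experience
        self.update(idx, priority)
        self.write = (self.write + 1) % self.capacity

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        while idx != 0:  # propagate the change up to the root
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def total(self):
        return self.tree[0]

    def sample(self, s):
        """Walk down from the root, picking the child whose span contains s."""
        idx = 0
        while 2 * idx + 1 < len(self.tree):
            left = 2 * idx + 1
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1
        return self.data[idx - self.capacity + 1]
```

To sample proportionally to priority, draw `s = random.uniform(0, tree.total())` and call `tree.sample(s)`; high-priority transitions cover a larger span of `[0, total)` and are chosen more often.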

2. **Hierarchical Reinforcement Learning (HRL)**:
   - HRL allows a **Manager** (high-level) model to select options, which are then executed by a **Worker** (low-level) model. The `ManagerModel` selects tasks (such as searching, summarizing, or generating), while the `WorkerModel` determines the specific actions to take.
   - The manager and worker use LSTM networks with fully connected layers, and each has its own replay memory and optimization process.
   - The Manager makes broad decisions about options, while the Worker carries out specific actions, enabling a layered approach to decision-making.
3. **RAG Summarizer**:
   - The `RAGSummarizer` uses a pre-trained language model (e.g., GPT-2) for summarizing and a SentenceTransformer for embedding-based retrieval. It breaks the input content into chunks, retrieves the sections most relevant to the query by cosine similarity, and generates a coherent summary.
   - It also implements a Least Recently Used (LRU) cache with persistent storage to avoid redundant computation and improve efficiency.
   - Summarized results are stored, and this module contributes directly to the generation of LLM training data.
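The cosine-similarity retrieval step can be sketched as follows. In the actual module the embeddings would come from a SentenceTransformer; here plain vectors stand in, and the function names are illustrative.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def retrieve_top_k(query_vec, chunk_vecs, chunks, k=2):
    """Rank content chunks by cosine similarity to the query embedding."""
    scored = sorted(zip(chunks, chunk_vecs),
                    key=lambda cv: cosine_similarity(query_vec, cv[1]),
                    reverse=True)
    return [chunk for chunk, _ in scored[:k]]
```

The top-ranked chunks are what a RAG summarizer would feed to the language model as context for generation.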
4. It is utilized in the HRL architecture, specifically by the Worker for evaluating actions and by the Manager in long-term decision-making.
5. **Knowledge Base**:
   - It supports saving and loading document embeddings, so the agent can retrieve relevant information for new queries from previously collected knowledge.
   - Adding to and retrieving from the knowledge base enriches the agent's context, letting it store and reuse information from past experiences to inform current tasks.
6. **Monte Carlo Tree Search (MCTS)**:
   - Nodes in the tree represent states (possible query refinements), and child nodes represent possible expansions (e.g., related query variations).
   - MCTS applies a `select`, `expand`, `simulate`, and `backpropagate` strategy to iteratively refine queries, scoring them on relevance and other metrics to converge on optimal searches.
   - It also integrates RL by backpropagating rewards based on the ranking score of the retrieved results.
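The four MCTS phases named above can be sketched as a toy loop. This is a generic UCB1-based illustration under assumed interfaces (`expand_fn` proposes query refinements, `reward_fn` is a cheap relevance heuristic), not the agent's actual implementation.

```python
import math

class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value_sum = 0.0

def ucb(node, c=1.4):
    """UCB1: average value plus an exploration bonus."""
    if node.visits == 0:
        return float("inf")  # always try unvisited children first
    exploit = node.value_sum / node.visits
    explore = c * math.sqrt(math.log(node.parent.visits) / node.visits)
    return exploit + explore

def mcts(root, expand_fn, reward_fn, iterations=200):
    for _ in range(iterations):
        # 1. Select: descend by repeatedly taking the max-UCB child.
        node = root
        while node.children:
            node = max(node.children, key=ucb)
        # 2. Expand: a previously visited leaf grows children.
        if node.visits > 0:
            node.children = [Node(s, node) for s in expand_fn(node.state)]
            if node.children:
                node = node.children[0]
        # 3. Simulate: score the state with the heuristic reward.
        reward = reward_fn(node.state)
        # 4. Backpropagate: update statistics up to the root.
        while node is not None:
            node.visits += 1
            node.value_sum += reward
            node = node.parent
    return max(root.children, key=lambda n: n.visits).state
```

After the iteration budget is spent, the most-visited child of the root is returned as the refined query.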
7. **Ranking Model**:
   - The ranking model, built with a neural network and the `SentenceTransformer`, ranks search results on features such as cosine similarity with the query, content length, keyword overlap, and domain authority.
   - The scores it assigns guide the MCTS process by augmenting the combined reward with ranking scores.
8. **Tree of Thought (ToT) Search**:
   - This module lets the agent generate a series of interconnected thoughts, exploring different perspectives or angles on a given query.
   - The `ToTNode` and `ToTSearch` classes let the agent generate thoughts, evaluate them, and navigate them as a tree, considering multiple potential paths to best answer the query.
   - It combines MCTS and RAG to synthesize responses from the generated thought paths.

### Training Process

Training is episodic: the agent works through queries from a predefined list. Each query initiates an episode, and the agent performs actions according to its learned policy:

1. **Search and Summarization**:
   - The agent performs search operations, gathering relevant content from online sources and using MCTS and the Ranking Model for prioritization.
   - The retrieved content is then summarized, and the relevant information is stored in the LLM training data.

2. **Knowledge Base and LLM Training Data Storage**:
   - Throughout training, the agent stores retrieved documents, query results, and summaries in its knowledge base and saves training data for future LLM fine-tuning.
   - The data is saved in JSONL format and includes metadata such as query terms, source links, and summaries, making it valuable for training language models.

3. **Experience Replay**:
   - Both the manager and worker models perform prioritized experience replay, sampling stored experiences from the SumTree according to TD-errors.
   - Replay reinforces successful transitions and updates the models' policies over time.

4. **Reward Calculation and Backpropagation**:
   - Rewards are calculated from ranking scores, cosine similarity with the query, and other custom factors (e.g., query complexity, state length).
   - These rewards are backpropagated through the MCTS and used to update the models' decision-making, ensuring continuous learning and adaptation.
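A JSONL training-data record of the kind described above might be written like this. The field names and file name are illustrative, not the repository's exact schema.

```python
import json

# Each line of a JSONL file is one standalone JSON record; the fields
# below (query, source, summary) are illustrative placeholders.
record = {
    "query": "history of the light bulb",
    "source": "https://example.com/article",
    "summary": "A short generated summary of the retrieved page.",
}

with open("llm_training_data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")
```

Because every line is self-contained JSON, the file can be streamed record-by-record during later fine-tuning without loading it all into memory.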
### Inference Process

During inference:

- The agent accepts a query, and the Manager model selects a high-level option according to its policy (e.g., search, summarize, or generate).
- Once an option is chosen, the Worker model executes the corresponding low-level actions. For a search operation, for example, it uses MCTS to refine the query, retrieves relevant web content, and processes it with the RAGSummarizer.
- Each inference step is augmented by the agent's existing knowledge base, producing more informed and contextually rich responses. If Tree of Thought (ToT) is employed, the agent synthesizes a coherent, comprehensive answer from the thought path.
### Saving and Loading

- The `save_worker_model` and `save_manager_model` functions save the worker and manager models independently.
- The `save` method preserves the overall state of the agent, including its knowledge base, replay memories, and models. This enables model reuse and persistent storage, so the agent can resume from saved states during training or deployment.

#### Example Usage

```bash
python main_menu.py --task train_agent
```

### 3. Testing the Tree of Thought Search Agent (`test_agent`)

Interactive Mode:
```bash
python main_menu.py --task test_agent
```

- **Input**: In interactive mode, enter a query and press Enter. The agent responds based on its training and the Tree of Thought methodology, traversing different thought paths to generate a response.

- **Exiting**: Type `quit` and press Enter to exit the interactive session. The agent saves any new knowledge it has gained and then exits.

---
## World Model

1. **Transformer**: The core sequence model; its output embeddings feed the downstream networks.
2. **Representation Network**: Encodes the Transformer output into a state representation, reducing dimensionality so the state is suitable for further processing.
3. **Dynamics Network**: Predicts the next state given a current state and an action, using layer normalization and a GELU activation function.
4. **Prediction Network**: Predicts both policy logits and a value estimate for a given state, outputting the probabilities of different actions as well as a single scalar value.
5. **MCTS**: Performs Monte Carlo Tree Search to evaluate the quality of actions over multiple iterations. It expands nodes based on the policy logits from the Prediction Network and simulates reward by backpropagating value estimates.
6. **PPO Agent**: Uses the policy and value estimates to compute the PPO loss, which updates the policy while constraining the KL divergence between the old and new policies.

The transformer uses beam search together with multi-token prediction to enrich the encoding produced by the representation network.

A generated sequence of tokens is an action. For example, if t is a token, then an action is:

a_1 = {t_1, ..., t_N}

and a policy is a sequence of actions:

π = {a_1, ..., a_M}
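The PPO update mentioned in item 6 clips the probability ratio between the new and old policies. A minimal scalar sketch of the clipped surrogate loss, under the standard PPO formulation (not the repository's exact code):

```python
import math

def ppo_clip_loss(old_logp, new_logp, advantage, eps=0.2):
    """Clipped PPO surrogate loss for a single action (scalar sketch)."""
    ratio = math.exp(new_logp - old_logp)  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1 + eps), 1 - eps) * advantage
    # Minimize the negative of the pessimistic (min) surrogate.
    return -min(unclipped, clipped)
```

The clipping keeps the updated policy close to the old one, which is the mechanism that bounds the effective KL divergence between them.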
## Training Details

The model is trained with the following components and techniques.

### Training Procedure

- **Data Loading**: Data is tokenized with attention to padding and truncation, and text is grouped into fixed-length sequences for efficient training.
- **Optimization**: Training uses the **AdamW** optimizer with a **CosineAnnealingLR** scheduler for learning-rate adjustment. The **gradient scaler** helps prevent overflow when training with mixed precision.
- **Gradient Accumulation**: Because the model is computationally heavy, gradients are accumulated over several steps to reduce memory usage.
- **Loss Functions**: The training process leverages a comprehensive set of custom loss functions:
**1. InfoNCE Loss (Info Noise Contrastive Estimation):**

Definition: A contrastive-learning loss that encourages similar samples to be close in the embedding space while pushing dissimilar samples apart.

Formula:
```
L_InfoNCE = -log[ exp(sim(z_i, z_j) / τ) / Σ_k exp(sim(z_i, z_k) / τ) ]
```
where sim() is the cosine similarity, τ is the temperature parameter, z_i and z_j are a positive pair, and the sum in the denominator runs over all other samples in the batch.
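The InfoNCE formula above can be computed directly for one anchor. This is a plain-Python sketch (the real loss would operate on batched tensors); the function names are illustrative.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def info_nce(anchor, positive, negatives, tau=0.1):
    """InfoNCE for one anchor: positive pair vs. a list of negatives."""
    logits = [cosine(anchor, positive) / tau] + \
             [cosine(anchor, n) / tau for n in negatives]
    # Loss = -log softmax probability of the positive (index 0),
    # computed with the log-sum-exp trick for stability.
    m = max(logits)
    log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
    return -(logits[0] - log_denom)
```

Lower temperature τ sharpens the distribution, penalizing hard negatives more strongly.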
**2. Covariance Regularization:**

Definition: A regularization term that encourages the learned representations to have uncorrelated dimensions, promoting more diverse and informative embeddings.

Formula:
```
L_cov = λ * Σ_{i≠j} Cov(i,j)^2
```
where Cov is the covariance matrix of the embeddings (the sum runs over its off-diagonal entries) and λ is a regularization coefficient.
**3. Dynamics Performance Loss:**

Definition: Measures the accuracy of predicted next states while also encouraging diverse predictions.

Formula:
```
L_dynamics = MSE(true_next_state, predicted_next_state) + λ * Var(predicted_next_state)
```
where MSE is the mean squared error, Var is the variance, and λ is a weighting factor.
**4. Thought Consistency Loss:**

Definition: Encourages consistency between true next states and perturbed next states.

Formula:
```
L_consistency = MSE(true_next_state, perturbed_next_state)
```
**5. Policy Value Joint Loss:**

Definition: Combines policy and value losses for reinforcement learning tasks.

Formula:
```
L_joint = CrossEntropy(policy_logits, true_policy) + λ * MSE(value_pred, true_value)
```
where λ is a weighting factor balancing the policy and value losses.
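The joint loss above is a weighted sum of a cross-entropy term over action distributions and a squared error on the value head. A scalar sketch (in practice this would be batched over tensors):

```python
import math

def policy_value_joint_loss(policy_probs, true_policy,
                            value_pred, true_value, lam=0.5):
    """Cross-entropy over the action distribution plus weighted value MSE."""
    cross_entropy = -sum(t * math.log(p)
                         for t, p in zip(true_policy, policy_probs) if t > 0)
    value_mse = (value_pred - true_value) ** 2
    return cross_entropy + lam * value_mse
```

With λ = 0 the loss reduces to pure policy cross-entropy; larger λ emphasizes accurate value estimates.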

**Diversity Reward:**

Formula:
```
R_diversity = λ * Σ_i,j (cos_sim(a_i, a_j)^2)
```
where cos_sim is the cosine similarity between action embeddings, and λ is a scaling factor.

**Exploration Reward:**

Formula:
```
R_exploration = λ * mean(Σ_a (1 / (visit_count(a) + 1)))
```
where λ is a scaling factor.
**KL Divergence Loss:**

Formula:
```
L_KL = Σ_{i=1}^{n} old_policy(i) * log(old_policy(i) / new_policy(i))
```
where KL is the Kullback-Leibler divergence, L_KL is the KL divergence loss, old_policy and new_policy are probability distributions, i indexes each possible outcome or action, and n is the total number of possible outcomes or actions.
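The discrete KL divergence above translates directly into code. A minimal sketch for list-valued distributions:

```python
import math

def kl_divergence(old_policy, new_policy):
    """KL(old || new) for two discrete probability distributions."""
    # Terms with p = 0 contribute 0 by convention (p * log p -> 0).
    return sum(p * math.log(p / q)
               for p, q in zip(old_policy, new_policy) if p > 0)
```

Note that KL is asymmetric: `kl_divergence(p, q)` generally differs from `kl_divergence(q, p)`, and it is zero only when the two distributions match.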
### Evaluation

After each epoch, the model is evaluated on the validation set, computing the average loss over the dataset. Evaluation uses the same loss functions as training but does not backpropagate, so it can run in inference mode.
### World Model Components

The World Model encapsulates components that model state representations, dynamics, predictions, and action encodings. These components interact with the Transformer to simulate and predict state transitions within the Tree of Thought framework.

#### RepresentationNetwork

**Function:**
Transforms the transformer's output embeddings into a compact state representation suitable for modeling and prediction tasks.

**Mathematical Operation:**
```
\[
\text{State} = \text{LayerNorm}\left(\text{Linear}(d_{\text{model}} \rightarrow d_{\text{state}})\left(\text{Linear}(vocab\_dim \rightarrow d_{\text{model}})(\text{Transformer Output})\right)\right)
\]
```
**Explanation:**
Sequential linear transformations project high-dimensional embeddings into a lower-dimensional state space, followed by layer normalization for stability.
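The two-projection-plus-normalization pipeline can be sketched without a deep-learning framework. This is a pure-Python stand-in for the two `Linear` layers and `LayerNorm` (the real network presumably uses `torch.nn`); weights here are plain lists of rows.

```python
import math

def linear(x, weight, bias):
    """y = W x + b for a single vector x; weight is a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def representation(transformer_output, w1, b1, w2, b2):
    """Project embeddings down to the state dimension, then normalize."""
    hidden = linear(transformer_output, w1, b1)   # vocab_dim -> d_model
    state = linear(hidden, w2, b2)                # d_model  -> d_state
    return layer_norm(state)
```

The second projection is where dimensionality reduction happens; the layer norm keeps state magnitudes stable regardless of the input scale.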
#### DynamicsNetwork

**Function:**
Predicts the next state given the current state and an action embedding.

**Mathematical Operation:**
```
\[
\text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
\]
```
#### PredictionNetwork

**Function:**
Predicts both the policy logits and a value estimate for a given state.

**Mathematical Operation:**
```
\[
(\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
\]
```
**Explanation:**
- **Policy Logits:** Used to derive action probabilities via softmax.
- **Value Estimate:** Represents the expected reward or quality of the current state.
#### ActionEncoder

**Function:**
Encodes discrete actions (thoughts) into continuous embeddings compatible with the DynamicsNetwork.

**Mathematical Operation:**
```
\[
\text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
\]
```
**Explanation:**
Converts action indices into dense vector representations, facilitating their integration into state transition modeling.
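An action encoder is essentially an embedding-table lookup. A minimal sketch (an `nn.Embedding` would play this role in a PyTorch implementation; the table values are arbitrary placeholders):

```python
# Hypothetical embedding table: one dense vector per discrete action id.
action_table = [
    [0.1, 0.2],  # action 0
    [0.3, 0.4],  # action 1
    [0.5, 0.6],  # action 2
]

def encode_action(action_index):
    """Map a discrete action index to its continuous embedding."""
    return action_table[action_index]
```

During training the table entries would be learned parameters, updated by gradients flowing back from the DynamicsNetwork.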

---

### Tree of Thought (ToT)

The Tree of Thought provides a structured representation of the possible thoughts/actions the model can take, organized hierarchically to enable efficient exploration during reasoning.

#### ThoughtNode

**Function:**
Represents a node in the Tree of Thought, corresponding to a specific action or thought.
**Attributes:**

- `name`: Identifier of the thought/action this node represents.
- `children`: List of child `ThoughtNode` instances representing possible subsequent thoughts/actions.
- `parent`: Reference to the parent `ThoughtNode`.

**Mathematical Representation:**
```
\[
\text{ThoughtNode} = (\text{name}, \{\text{children}\})
\]
```
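The attribute list above maps naturally onto a small dataclass. This is an illustrative sketch of the node structure, not the repository's `ThoughtNode` implementation:

```python
from dataclasses import dataclass, field

@dataclass
class ThoughtNode:
    """One node in the Tree of Thought: a named thought with child expansions."""
    name: str
    children: list = field(default_factory=list)
    parent: "ThoughtNode | None" = None

    def add_child(self, name):
        """Attach and return a child thought, wiring the parent pointer."""
        child = ThoughtNode(name, parent=self)
        self.children.append(child)
        return child
```

Parent pointers let a search procedure walk back up the tree when backpropagating scores along a thought path.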
#### State

**Function:**
Represents the current state in the Tree of Thought, tracking the accumulated sequence representation and the current `ThoughtNode`.

**Attributes:**

- `representation`: The accumulated sequence of state representations.
- `dynamics_network`: Reference to the `DynamicsNetwork` for state transitions.
- `action_encoder`: Reference to the `ActionEncoder` for encoding actions.
- `thought_node`: Reference to the current `ThoughtNode` in the Tree of Thought.

**Mathematical Operations:**
```
\[
\text{Next State} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
\]
\[
\text{New Representation} = \text{Concat}(\text{Current Representation}, \text{Next State}.\text{unsqueeze}(1))
\]
```
**Procedure:**

1. **Action Encoding:**
   ```
   \[
   \text{Action Embedding} = \text{ActionEncoder}(\text{Action Index})
   \]
   ```
2. **State Extraction:**
   ```
   \[
   \text{Current State} = \text{representation}[:, -1, :]
   \]
   ```
3. **State Transition:**
   ```
   \[
   \text{Next State Representation} = \text{DynamicsNetwork}(\text{Current State}, \text{Action Embedding})
   \]
   ```
4. **Representation Update:**
   ```
   \[
   \text{New Representation} = \text{Concat}(\text{representation}, \text{Next State Representation}.\text{unsqueeze}(1))
   \]
   ```
5. **Thought Node Update:**
   - Navigate to the child `ThoughtNode` corresponding to the applied action.
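The five-step procedure above can be sketched end to end. Here `dynamics` and `encode` are plain functions standing in for the DynamicsNetwork and ActionEncoder, and state vectors are lists rather than tensors; the class shape is illustrative, not the repository's `State` API.

```python
class State:
    """Minimal sketch of the state-transition procedure."""

    def __init__(self, representation, thought_node, dynamics, encode):
        self.representation = representation  # list of state vectors
        self.thought_node = thought_node      # current node in the ToT
        self.dynamics = dynamics
        self.encode = encode

    def apply_action(self, action_index):
        embedding = self.encode(action_index)            # 1. action encoding
        current = self.representation[-1]                # 2. state extraction
        next_state = self.dynamics(current, embedding)   # 3. state transition
        self.representation.append(next_state)           # 4. representation update
        self.thought_node = self.thought_node.children[action_index]  # 5. node update
        return next_state
```

Each applied action both grows the representation sequence and descends one level in the Tree of Thought, keeping the two in lockstep.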

---

### Monte Carlo Tree Search (MCTS)

MCTS makes optimal decisions by traversing the Tree of Thought, using statistical methods to balance exploration of new actions against exploitation of known rewarding actions.

#### MCTSNode
**Function:**
Represents a node in the MCTS search tree, encapsulating a specific state in the search process.

**Attributes:**

- `state`: Current state, represented by a `State` instance.
- `parent`: Reference to the parent `MCTSNode`.
- `action`: Action taken to reach this node.
- `children`: Dictionary mapping actions to child `MCTSNode` instances.
- `visit_count`: Number of times this node has been visited.
- `value_sum`: Cumulative value obtained from simulations passing through this node.
- `prior`: Prior probability of selecting this action, derived from the policy logits.
- `entropy`: Entropy of the policy distribution at this node.
- `variance`: Variance of the policy distribution at this node.

**Mathematical Representation:**

Each `MCTSNode` can be written as:
```
\[
\text{MCTSNode} = (\text{state}, \text{parent}, \text{action}, \{\text{children}\}, \text{visit\_count}, \text{value\_sum}, \text{prior}, \text{entropy}, \text{variance})
\]
```
#### MCTS Algorithm

The `MCTS` class implements the Monte Carlo Tree Search algorithm tailored to the LightBulb model's architecture.

**Key Steps:**

1. **Initialization:**
   - Initialize with the `PredictionNetwork`, `DynamicsNetwork`, `ActionEncoder`, number of iterations (`num_iterations`), exploration constant (`exploration_constant`), beam size (`beam_size`), and number of tokens to predict (`n_tokens_predict`).

2. **Search with Beam (`search_with_beam`):**
   - **Objective:** Explore the Tree of Thought using beam search augmented with MCTS principles.
   - **Procedure:**
     1. **Root Node Evaluation and Backpropagation:**
        - Evaluate the root node to obtain policy logits and a value estimate.
        - Backpropagate the value estimate to update visit counts and value sums.
     2. **Beam Initialization:**
        - Start with a beam containing the root node.
     3. **Iterative Expansion:**
        - For each iteration up to `num_iterations`:
          - **Candidate Collection:**
            - For each node in the current beam:
              - If it is a leaf node, evaluate it and backpropagate its value.
              - If it has children, select the top `beam_size` actions based on UCB scores.
              - For each selected action:
                - Predict a sequence of `n_tokens_predict` actions.
                - Accumulate scores, entropy, and variance.
                - Add the candidate sequence to `all_candidates`.
          - **Beam Pruning:**
            - Sort all candidates by a combined score:
              ```
              \[
              \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
              \]
              ```
            - Retain the top `beam_size` candidates for the next iteration.
     4. **Result Extraction:**
        - After completing the iterations, select the best action sequence from the final beam.
3. **Evaluation (`evaluate`):**
   - **Function:** Computes the policy logits and value estimate for a given node's state.
   - **Procedure:**
     - Extract the last time step's state representation.
     - Pass it through the `PredictionNetwork` to obtain policy logits and a value estimate.
     - Convert the logits to probabilities using softmax.
     - Calculate the entropy and variance of the policy distribution.
     - Expand the node by creating child nodes from the Tree of Thought, assigning priors from the policy probabilities.
   - **Mathematical Operations:**
     ```
     \[
     (\text{Policy Logits}, \text{Value Estimate}) = \text{PredictionNetwork}(\text{State})
     \]
     \[
     P = \text{softmax}(\text{Policy Logits})
     \]
     \[
     \text{Entropy} = -\sum_{i=1}^{V} P_i \log P_i
     \]
     \[
     \text{Variance} = \text{Var}(P)
     \]
     ```
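The softmax, entropy, and variance computations in the evaluation step can be sketched directly. A plain-Python illustration (real code would use batched tensors):

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]  # shift by max for stability
    total = sum(exps)
    return [e / total for e in exps]

def policy_stats(logits):
    """Probabilities, entropy, and variance of a policy distribution."""
    p = softmax(logits)
    entropy = -sum(pi * math.log(pi) for pi in p if pi > 0)
    mean = sum(p) / len(p)
    variance = sum((pi - mean) ** 2 for pi in p) / len(p)
    return p, entropy, variance
```

A uniform distribution maximizes entropy and has zero variance; a peaked one does the opposite, which is why the search uses the two statistics as complementary confidence signals.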
4. **Backpropagation (`backpropagate`):**
   - **Function:** Updates `visit_count` and `value_sum` for every node on the path from the evaluated node back to the root.
   - **Procedure:**
     ```
     \[
     \text{For each node on the path:} \quad
     \text{node.visit\_count} \mathrel{+}= 1, \quad
     \text{node.value\_sum} \mathrel{+}= \text{Value Estimate}
     \]
     ```
5. **Upper Confidence Bound (UCB) Score (`ucb_score`):**
   - **Function:** Balances exploration of less-visited nodes against exploitation of high-value nodes.
   - **Mathematical Operation:**
     ```
     \[
     \text{UCB Score} = \text{Average Value} + \text{Exploration Term} + \text{Entropy Term} + \text{Variance Term}
     \]
     Where:
     \[
     \text{Average Value} = \frac{\text{value\_sum}}{\text{visit\_count}}
     \]
     \[
     \text{Exploration Term} = \text{exploration\_constant} \times \text{prior} \times \frac{\sqrt{\text{total\_visits}}}{1 + \text{visit\_count}}
     \]
     \[
     \text{Entropy Term} = -0.1 \times \text{entropy}
     \]
     \[
     \text{Variance Term} = 0.05 \times \text{variance}
     \]
     ```
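The four UCB terms above translate into a one-liner per term. A sketch under the stated coefficients (the default `exploration_constant` here is an assumption, not the repository's value):

```python
import math

def ucb_score(value_sum, visit_count, prior, total_visits,
              entropy, variance, exploration_constant=1.25):
    """UCB score: average value, prior-weighted exploration bonus,
    entropy penalty (-0.1), and variance bonus (+0.05), as defined above."""
    average_value = value_sum / visit_count if visit_count > 0 else 0.0
    exploration = (exploration_constant * prior *
                   math.sqrt(total_visits) / (1 + visit_count))
    return average_value + exploration - 0.1 * entropy + 0.05 * variance
```

Unvisited nodes get no average-value term but a large exploration bonus (divided only by 1), which is what pulls the search toward them early on.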
6. **Best Action Sequence Extraction (`best_action_sequence`):**
   - **Function:** Extracts the most promising action sequence from the MCTS tree after all iterations.
   - **Procedure:**
     - Traverse all possible sequences in the tree.
     - Score each sequence based on cumulative visit counts, entropy, and variance.
     - Select the top `beam_size` sequences and return the best one.

---

### Beam Search with Multi-Token Prediction
**Procedure:**

1. **Beam Initialization:**
   - Start with a beam containing only the start-of-sequence (BOS) token:
     ```
     \[
     \text{beam} = \left\{ \left( \text{seq} = [\text{BOS}], \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0 \right) \right\}
     \]
     ```
2. **Iterative Expansion:**
   - For each iteration up to \( \frac{\text{max\_length}}{n\_tokens\_predict} \):
     - For each sequence in the beam:
       - Predict the next \( n\_tokens\_predict \) tokens.
       - Calculate their probabilities.
       - Select the top-k token sequences based on cumulative scores.
3. **Beam Pruning:**
   - After expanding all sequences, retain only the top `beam_size` candidates based on combined scores incorporating log probabilities, entropy, and variance.
4. **Termination:**
   - Continue until the maximum length is reached or all sequences end with the end-of-sequence (EOS) token.

**Mathematical Operations:**
```
\[
\text{Entropy} = -\sum_{i=1}^{V} P_i \log P_i
\]
\[
\text{Variance} = \text{Var}(P)
\]
```
Where \( P \) is the probability distribution over the vocabulary.
### Upper Confidence Bound (UCB) in MCTS

**Purpose:** Balances exploration of less-visited nodes and exploitation of high-value nodes during tree traversal.

**Mathematical Formulation:**
\[
\text{UCB Score} = \text{Average Value} + \text{Exploration Term} + \text{Entropy Term} + \text{Variance Term}
\]
Where:
\[
\text{Average Value} = \frac{\text{value\_sum}}{\text{visit\_count}}
\]
\[
\text{Exploration Term} = \text{exploration\_constant} \times \text{prior} \times \frac{\sqrt{\text{total\_visits}}}{1 + \text{visit\_count}}
\]
\[
\text{Entropy Term} = -0.1 \times \text{entropy}
\]
\[
\text{Variance Term} = 0.05 \times \text{variance}
\]

**Explanation:**
- **Average Value:** Encourages exploitation of nodes with high expected rewards.
- **Exploration Term:** Encourages visiting nodes with higher uncertainty (low visit counts).
- **Entropy and Variance Terms:** Modulate preferences based on the policy distribution's entropy and variance, promoting diverse and balanced exploration.

### Entropy and Variance Calculations

**Purpose:** Measure the uncertainty and diversity of the policy distribution, influencing the exploration-exploitation balance.

**Mathematical Formulation:**
\[
\text{Entropy} = -\sum_{i=1}^{V} P_i \log P_i
\]
\[
\text{Variance} = \frac{1}{V} \sum_{i=1}^{V} (P_i - \mu)^2
\]
Where:
- \( P_i \): Probability of action \( i \).
- \( \mu = \frac{1}{V} \sum_{i=1}^{V} P_i \): Mean probability.

**Explanation:**
- **Entropy:** Quantifies the unpredictability of the distribution. High entropy indicates a more uniform distribution, promoting exploration.
- **Variance:** Measures the spread of the probabilities. High variance can indicate diverse preferences among actions.

---
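Both statistics are cheap to compute. A minimal pure-Python sketch (the project works with tensors, but the arithmetic is identical):

```python
import math

def entropy(p):
    # H(P) = -sum_i P_i log P_i; zero-probability terms contribute nothing.
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def variance(p):
    # Spread of the probabilities around their mean mu = 1/V.
    mu = sum(p) / len(p)
    return sum((pi - mu) ** 2 for pi in p) / len(p)

uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [0.97, 0.01, 0.01, 0.01]
# The uniform distribution has maximal entropy and zero variance;
# the peaked one has low entropy and high variance.
print(entropy(uniform) > entropy(peaked))    # True
print(variance(peaked) > variance(uniform))  # True
```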
## Inference Workflow

The inference process in the LightBulb model can operate in two distinct modes:

1. **Without World Model:** Utilizes the Transformer with beam search to generate text directly.
2. **With World Model and Tree of Thought:** Employs the World Model components alongside MCTS and ToT for generating a sequence of thoughts/actions.

### Inference Modes

#### Language Model

**Procedure:**

1. **Input Processing:**
   - Tokenize the input query.
   - Encode tokens into embeddings via the Transformer.

2. **Beam Search Generation:**
   - Use the Transformer's `generate_with_beam_search` method.
   - Predict multiple tokens at each step (`n_tokens_predict`).
   - Maintain a beam of top `beam_size` sequences based on cumulative scores.

3. **Output Decoding:**
   - Select the best sequence based on scores.
   - Decode token IDs back into human-readable text.

**Mathematical Operations:**

- **Beam Search Score:**
  \[
  \text{Score} = \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
  \]

- **Entropy and Variance:**
  \[
  \text{Entropy} = -\sum P_i \log P_i
  \]
  \[
  \text{Variance} = \text{Var}(P)
  \]

#### World Model and Tree of Thought

**Procedure:**

1. **Input Processing:**
   - Tokenize the input query.
   - Encode tokens into embeddings via the Transformer.
   - Transform the output embeddings into a state representation using the `RepresentationNetwork`.

2. **MCTS Initialization:**
   - Create a root `MCTSNode` with the initial state and the root `ThoughtNode` from ToT.
   - Evaluate the root node to obtain policy logits and value estimates.
   - Backpropagate the evaluation to update visit counts and value sums.

3. **MCTS Iterations with Beam Search:**
   - For a predefined number of iterations (`num_iterations`):
     - **Beam Expansion:**
       - For each node in the current beam:
         - If it's a leaf node, evaluate and backpropagate its value.
         - Select top `beam_size` actions based on UCB scores.
         - For each selected action:
           - Predict a sequence of `n_tokens_predict` actions.
           - Accumulate scores, entropy, and variance.
           - Add the candidate sequence to `all_candidates`.
     - **Beam Pruning:**
       - Sort all candidates based on a combined score:
         \[
         \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
         \]
       - Retain the top `beam_size` candidates for the next iteration.

4. **Output Generation:**
   - After completing iterations, select the best action sequence from the final beam.
   - Return the sequence of actions (thoughts) as the output.

**Mathematical Operations:**

- **UCB Score:**
  \[
  \text{UCB Score} = \frac{\text{value\_sum}}{\text{visit\_count}} + \text{exploration\_constant} \times \text{prior} \times \frac{\sqrt{\text{total\_visits}}}{1 + \text{visit\_count}} - 0.1 \times \text{entropy} + 0.05 \times \text{variance}
  \]

- **Combined Score:**
  \[
  \text{Combined Score} = \text{Score} - 0.1 \times \text{Entropy} + 0.05 \times \text{Variance}
  \]

---
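The beam-pruning step reduces to a single `sorted` call over candidate tuples. A toy sketch with invented candidates (the real tuples hold MCTS nodes rather than string labels):

```python
beam_size = 3
# (label, score, cumulative entropy, cumulative variance) -- invented values.
all_candidates = [
    ("a", 1.0, 2.0, 0.0),  # combined: 1.0 - 0.2 + 0.00 = 0.80
    ("b", 0.9, 0.0, 0.0),  # combined: 0.90
    ("c", 1.2, 5.0, 1.0),  # combined: 1.2 - 0.5 + 0.05 = 0.75
    ("d", 0.5, 0.0, 0.0),  # combined: 0.50
]
# Combined score from the formula above: score - 0.1*entropy + 0.05*variance.
combined = lambda x: x[1] - 0.1 * x[2] + 0.05 * x[3]
beam = sorted(all_candidates, key=combined, reverse=True)[:beam_size]
print([c[0] for c in beam])  # ['b', 'a', 'c']
```

Note that candidate `c` has the highest raw score but is demoted below `a` and `b` once its entropy penalty is applied.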
### Mode: Without World Model

**Step 1: Input Tokenization and Encoding**

1. **Tokenization:**
   \[
   \text{input\_ids} = \text{tokenizer.encode(query, return\_tensors='pt')}
   \]
   - Shape: \((\text{batch\_size}=1, \text{seq\_len})\)

2. **Encoding via Transformer:**
   \[
   \text{transformer\_output} = \text{model\_transformer}(input\_ids, input\_ids)
   \]
   - Shape: \((\text{batch\_size}=1, \text{seq\_len}, \text{d\_model})\)

**Step 2: Beam Search Generation**

1. **Beam Search Initialization:**
   \[
   \text{beam} = \left\{ \left( \text{seq} = [\text{BOS}], \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0 \right) \right\}
   \]

2. **Iterative Expansion:**
   - For each iteration up to \(\frac{\text{max\_length}}{n\_tokens\_predict}\):
     - For each sequence in the beam:
       - If the last token is EOS, retain the sequence.
       - Else, predict the next \( n\_tokens\_predict \) tokens.
       - Calculate probabilities, entropy, and variance.
       - Select top-k tokens for each position based on beam size.
       - Generate all possible continuations.
       - **Score Calculation:**
         \[
         \text{new\_score} = \text{score} + \sum_{t=1}^{n} \log P(\text{token}_t | \text{tokens}_{<t})
         \]
       - **Entropy and Variance Accumulation:**
         \[
         \text{new\_entropy} = \text{cum\_entropy} + \sum_{t=1}^{n} \text{Entropy}_t
         \]
         \[
         \text{new\_variance} = \text{cum\_variance} + \sum_{t=1}^{n} \text{Variance}_t
         \]
       - **Candidate Aggregation:**
         - Append new sequences to `all_candidates`.

3. **Beam Pruning:**
   - Sort `all_candidates` based on combined scores.
   - Retain the top `beam_size` candidates for the next iteration.

**Step 3: Output Decoding**

1. **Select Best Sequence:**
   \[
   \text{best\_sequence} = \text{beam}[0][0]
   \]

2. **Decode Tokens to Text:**
   \[
   \text{generated\_text} = \text{tokenizer.decode(best\_sequence, skip\_special\_tokens=True)}
   \]

---
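To make the expand-score-prune loop concrete, here is a toy beam search over a hand-written bigram table. The vocabulary and probabilities are invented for illustration; in the real model the Transformer scores the tokens.

```python
import math

# Toy next-token distributions: P(next | last_token). Invented for illustration.
probs = {
    "<BOS>": {"the": 0.6, "a": 0.4},
    "the":   {"cat": 0.7, "dog": 0.3},
    "a":     {"dog": 0.8, "cat": 0.2},
    "cat":   {"<EOS>": 1.0},
    "dog":   {"<EOS>": 1.0},
}

beam_size = 2
beam = [(["<BOS>"], 0.0)]  # (sequence, cumulative log-probability)
for _ in range(3):
    candidates = []
    for seq, score in beam:
        if seq[-1] == "<EOS>":
            candidates.append((seq, score))  # finished sequences pass through
            continue
        for tok, p in probs[seq[-1]].items():
            candidates.append((seq + [tok], score + math.log(p)))
    # Prune: keep the top `beam_size` sequences by cumulative score.
    beam = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]

best_seq, best_score = beam[0]
print(best_seq)  # ['<BOS>', 'the', 'cat', '<EOS>']
```

The cumulative log-probability plays the role of the `Score` term above; the entropy and variance terms would simply be added to the sort key.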
### Mode: With World Model and Tree of Thought

**Step 1: Input Tokenization and Encoding**

1. **Tokenization:**
   \[
   \text{input\_ids} = \text{tokenizer.encode(query, return\_tensors='pt')}
   \]
   - Shape: \((\text{batch\_size}=1, \text{seq\_len})\)

2. **Encoding via Transformer:**
   \[
   \text{transformer\_output} = \text{model\_transformer}(input\_ids, input\_ids)
   \]
   - Shape: \((\text{batch\_size}=1, \text{seq\_len}, \text{d\_model})\)

3. **State Representation:**
   \[
   \text{initial\_representation} = \text{RepresentationNetwork}(\text{transformer\_output})[:, -1, :].unsqueeze(1)
   \]
   - Shape: \((\text{batch\_size}=1, 1, \text{state\_dim})\)

4. **State Initialization:**
   \[
   \text{initial\_state} = \text{State}(\text{representation}=\text{initial\_representation}, \text{dynamics\_network}=dynamics\_network, \text{action\_encoder}=action\_encoder, \text{thought\_node}=root\_thought\_node)
   \]

**Step 2: MCTS Initialization and Root Node Evaluation**

1. **MCTS Instance Creation:**
   \[
   \text{mcts} = \text{MCTS}(\text{prediction\_network}, \text{dynamics\_network}, \text{action\_encoder}, \text{num\_iterations}=mcts\_iterations, \text{exploration\_constant}=exploration\_constant)
   \]

2. **Root Node Creation:**
   \[
   \text{root\_node} = \text{MCTSNode}(\text{state}=\text{initial\_state}, \text{thought\_node}=root\_thought\_node)
   \]

3. **Root Node Evaluation:**
   \[
   \text{value\_estimate} = \text{mcts.evaluate}(\text{root\_node})
   \]

4. **Backpropagation:**
   \[
   \text{mcts.backpropagate}(\text{root\_node}, \text{value\_estimate})
   \]

**Step 3: MCTS Iterations with Beam Search**

1. **Beam Initialization:**
   \[
   \text{beam} = \left\{ \left( \text{root\_node}, \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0, \text{action\_sequence} = [] \right) \right\}
   \]

2. **Iterative Expansion:**
   - For each iteration up to `num_iterations`:
     - **Candidate Collection:**
       - For each node in the current beam:
         - **Leaf Evaluation:**
           - If `node.is_leaf()`:
             \[
             \text{value\_estimate} = \text{mcts.evaluate}(\text{node})
             \]
             \[
             \text{mcts.backpropagate}(\text{node}, \text{value\_estimate})
             \]
         - **Child Selection:**
           - If `node.children` is not empty:
             - Calculate `total_visits`:
               \[
               \text{total\_visits} = \sum_{\text{child} \in \text{node.children}} \text{child.visit\_count}
               \]
             - Select top `beam_size` actions based on UCB scores:
               \[
               \text{sorted\_children} = \text{sorted}(\text{node.children.items()}, \text{key}=\lambda \text{item}: \text{item}[1].ucb\_score(\text{total\_visits}, \text{exploration\_constant}), \text{reverse}=True)[:\text{beam\_size}]
               \]
         - **Action Sequence Prediction:**
           - For each selected action:
             - Initialize `current_node`, `current_sequence`, `current_score`, `current_entropy`, `current_variance`.
             - **Multi-Token Prediction:**
               - For each step in `n_tokens_predict`:
                 - If `current_node.is_leaf()`:
                   \[
                   \text{value\_estimate} = \text{mcts.evaluate}(\text{current\_node})
                   \]
                   \[
                   \text{mcts.backpropagate}(\text{current\_node}, \text{value\_estimate})
                   \]
                 - If `current_node.children` is empty, break.
                 - Calculate `total_visits` for the new node.
                 - Select the action with the highest UCB score:
                   \[
                   (\text{next\_action}, \text{next\_node}) = \text{max}(\text{current\_node.children.items()}, \text{key}=\lambda \text{item}: \text{item}[1].ucb\_score(\text{total\_visits}, \text{exploration\_constant}))
                   \]
                 - **Score Update:**
                   \[
                   \text{current\_score} \mathrel{+}= \frac{\text{next\_node.value\_sum}}{\text{next\_node.visit\_count}} \quad \text{if} \quad \text{next\_node.visit\_count} > 0 \quad \text{else} \quad 0
                   \]
                 - **Entropy and Variance Update:**
                   \[
                   \text{current\_entropy} \mathrel{+}= \text{next\_node.entropy}
                   \]
                   \[
                   \text{current\_variance} \mathrel{+}= \text{next\_node.variance}
                   \]
                 - Append `next_action` to `current_sequence`.
                 - Update `current_node` to `next_node`.
             - **Candidate Aggregation:**
               \[
               \text{all\_candidates.append}((\text{current\_node}, \text{current\_score}, \text{current\_entropy}, \text{current\_variance}, \text{current\_sequence}))
               \]
     - **Beam Pruning:**
       \[
       \text{beam} = \text{sorted}(\text{all\_candidates}, \text{key}=\lambda x: x[1] - 0.1 \times x[2] + 0.05 \times x[3], \text{reverse}=True)[:\text{beam\_size}]
       \]

3. **Termination:**
   - Stop early if no candidates remain or all beams have reached terminal nodes.

4. **Result Extraction:**
   \[
   \text{best\_sequence} = \text{beam}[0][4]
   \]
   - Return `best_sequence` as the generated sequence of actions (thoughts).

---
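The per-iteration logic of Step 3 can be condensed into a dictionary-based sketch. The dictionaries are stand-ins for `MCTSNode`, and the evaluation, backpropagation, and UCB functions are injected; the project's actual classes differ.

```python
def mcts_beam_step(beam, beam_size, n_tokens_predict, evaluate, backpropagate, ucb):
    """One MCTS iteration: extend each beam entry n_tokens_predict steps by UCB."""
    all_candidates = []
    for node, score, ent, var, seq in beam:
        if not node["children"]:  # leaf: evaluate and backpropagate its value
            backpropagate(node, evaluate(node))
        cur, c_score, c_ent, c_var, c_seq = node, score, ent, var, list(seq)
        for _ in range(n_tokens_predict):
            if not cur["children"]:
                break
            total = sum(ch["visit_count"] for ch in cur["children"].values())
            # Greedily follow the child with the highest UCB score.
            action, nxt = max(cur["children"].items(),
                              key=lambda kv: ucb(kv[1], total))
            if nxt["visit_count"] > 0:
                c_score += nxt["value_sum"] / nxt["visit_count"]
            c_ent += nxt["entropy"]
            c_var += nxt["variance"]
            c_seq.append(action)
            cur = nxt
        all_candidates.append((cur, c_score, c_ent, c_var, c_seq))
    # Prune with the combined score from the formulas above.
    return sorted(all_candidates,
                  key=lambda x: x[1] - 0.1 * x[2] + 0.05 * x[3],
                  reverse=True)[:beam_size]

def leaf(value):
    return {"children": {}, "visit_count": 1, "value_sum": value,
            "entropy": 0.0, "variance": 0.0}

root = {"children": {"x": leaf(1.0), "y": leaf(0.2)}, "visit_count": 1,
        "value_sum": 0.0, "entropy": 0.0, "variance": 0.0}
# Use average value as a trivial UCB stand-in for the demonstration.
avg = lambda n, total: n["value_sum"] / n["visit_count"] if n["visit_count"] else 0.0
beam = mcts_beam_step([(root, 0.0, 0.0, 0.0, [])], beam_size=1,
                      n_tokens_predict=1, evaluate=lambda n: 0.0,
                      backpropagate=lambda n, v: None, ucb=avg)
print(beam[0][4])  # ['x']
```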
## Integration of Components

The inference process seamlessly integrates multiple components to facilitate advanced reasoning:

1. **Transformer for Sequence Encoding and Generation:**
   - Processes input sequences and generates embeddings.
   - Facilitates beam search for direct text generation.

2. **World Model for State Representation and Dynamics:**
   - `RepresentationNetwork` encodes transformer outputs into state representations.
   - `DynamicsNetwork` predicts state transitions based on actions.
   - `PredictionNetwork` provides policy logits and value estimates.
   - `ActionEncoder` encodes actions into embeddings for state transitions.

3. **Tree of Thought for Structured Reasoning:**
   - Organizes possible thoughts/actions hierarchically.
   - Enables systematic exploration of reasoning paths.

4. **Monte Carlo Tree Search for Strategic Exploration:**
   - Utilizes ToT to explore potential reasoning paths.
   - Balances exploration and exploitation using UCB scores.
   - Incorporates beam search with multi-token prediction to handle complex, multi-step actions.

**Workflow Integration:**

- **Input Processing:** The input query is tokenized and encoded via the Transformer.
- **State Representation:** Encoded inputs are transformed into initial state representations.
- **MCTS Integration:** MCTS uses the World Model components to explore and evaluate possible thought sequences within the Tree of Thought.
- **Beam Search:** Multi-token beam search within MCTS ensures diverse and coherent exploration of actions.
- **Output Generation:** The best sequence of thoughts/actions is extracted and returned as the inference result.

---

I use Trees of Thought to structure sets of policies and sequences of actions. These tree structures provide the World Model with a general thought structure and pattern, similar to how humans create thought patterns for solving certain kinds of problems (e.g. understand, describe, analyse).

Here are some example Trees of Thought:

```mermaid
graph TD
    A[Problem-Solving Process] --> B[Problem Identification]
    A --> C[Problem Analysis]
    A --> D[Solution Generation]
    A --> E[Implementation]
    A --> F[Evaluation and Adjustment]
    B --> B1[Define the Problem]
    B --> B2[Identify Stakeholders]
    B --> B3[Determine Constraints]
    B --> B4[Recognize Problem Type]
    B --> B5[Historical Context]
    C --> C1[Root Cause Analysis]
    C --> C2[System Mapping]
    C --> C3[Data Collection]
    C --> C4[Impact Assessment]
    C --> C5[Theoretical Framework]
    D --> D1[Creative Problem Solving]
    D --> D2[Analytical Approach]
    D --> D3[Mathematical Computation]
    D --> D4[Decision Making]
    E --> E1[Action Planning]
    E --> E2[Resource Allocation]
    E --> E3[Change Management]
    F --> F1[Verification]
    F --> F2[Performance Metrics]
    F --> F3[Feedback Loops]
    F --> F4[Continuous Improvement]
    C3 --> C3a[Quantitative Data]
    C3 --> C3b[Qualitative Data]
    C3 --> C3c[Data Validation]
    D1 --> D1a[Divergent Thinking]
    D1 --> D1b[Convergent Thinking]
    D1 --> D1c[Lateral Thinking]
    D2 --> D2a[Logical Reasoning]
    D2 --> D2b[Critical Analysis]
    D2 --> D2c[Systems Thinking]
    D3 --> D3a[Basic Operations]
    D3 --> D3b[Advanced Operations]
    D3 --> D3c[Computational Methods]
    D4 --> D4a[Decision Trees]
    D4 --> D4b[Multi-Criteria Analysis]
    D4 --> D4c[Probabilistic Reasoning]
    G[Cross-Cutting Considerations] --> G1[Ethical Framework]
    G --> G2[Stakeholder Management]
    G --> G3[Interdisciplinary Connections]
    G --> G4[Technological Integration]
    G --> G5[Emotional Intelligence]
    G --> G6[Collaborative Problem Solving]
    G1 --> G1a[Value-based Decision Making]
    G1 --> G1b[Long-term Consequences]
    G2 --> G2a[Direct Stakeholders]
    G2 --> G2b[Indirect Stakeholders]
    G2 --> G2c[Conflicting Interests]
    G3 --> G3a[Related Fields]
    G3 --> G3b[Cross-disciplinary Impact]
    G4 --> G4a[AI-assisted Problem Solving]
    G4 --> G4b[Data-driven Insights]
    G4 --> G4c[Digital Collaboration Tools]
    G5 --> G5a[Self-Awareness]
    G5 --> G5b[Empathy]
    G5 --> G5c[Stress Management]
    G6 --> G6a[Team Dynamics]
    G6 --> G6b[Communication Strategies]
    G6 --> G6c[Conflict Resolution]
    H[Computational Considerations] --> H1[CPU Operations]
    H --> H2[GPU Parallelization]
    H --> H3[Floating-Point Precision]
    I[Order of Operations] --> I1[Parentheses]
    I --> I2[Exponents]
    I --> I3[Multiplication and Division]
    I --> I4[Addition and Subtraction]
    J[Critical Thinking] --> J1[Assumptions Questioning]
    J --> J2[Bias Recognition]
    K[Future Perspective] --> K1[Short-term Projections]
    K --> K2[Long-term Scenarios]
    K --> K3[Potential Impacts]
    L[Learning and Adaptation] --> L1[Reflective Practice]
    L --> L2[Knowledge Transfer]
    L --> L3[Adaptive Problem Solving]
```

```mermaid
graph TD
    A[Meta-Cognitive Strategies] --> B[Creative Problem Solving]
    A --> C[Systems Thinking]
    A --> D[Decision Making]
    A --> E[Emotional Intelligence]
    A --> F[Collaborative Problem Solving]
    B --> B1[Divergent Thinking]
    B --> B2[Convergent Thinking]
    B --> B3[Lateral Thinking]
    C --> C1[Holistic Perspective]
    C --> C2[Feedback Loops]
    C --> C3[Emergent Properties]
    D --> D1[Decision Trees]
    D --> D2[Multi-Criteria Decision Analysis]
    D --> D3[Probabilistic Reasoning]
    E --> E1[Self-Awareness]
    E --> E2[Empathy]
    E --> E3[Stress Management]
    F --> F1[Team Dynamics]
    F --> F2[Communication Strategies]
    F --> F3[Conflict Resolution]
    G[Learning and Adaptation]
    A --> G
    G --> G1[Reflective Practice]
    G --> G2[Knowledge Transfer]
    G --> G3[Adaptive Problem Solving]
    H[Ethical Framework]
    A --> H
    H --> H1[Value-based Decision Making]
    H --> H2[Stakeholder Analysis]
    H --> H3[Long-term Consequences]
    I[Technological Integration]
    A --> I
    I --> I1[AI-assisted Problem Solving]
    I --> I2[Data-driven Insights]
    I --> I3[Digital Collaboration Tools]
```

## Citation

If you use this model in your research, please cite the author.

---
# Model Card for LightBulb

## Overview

**LightBulb** is an advanced framework designed to train and utilize language models and autonomous web search agents. It integrates hierarchical reinforcement learning, Monte Carlo Tree Search (MCTS), and Tree of Thought (ToT) architectures to enable sophisticated reasoning and decision-making capabilities. The framework supports both training and inference for language models, web search agents, and comprehensive world models.

## Installation

To install the necessary dependencies, run:

```bash
pip install huggingface_hub torch transformers datasets argparse
```

## Getting Started
### Download the Repository

Use the `huggingface_hub` library to download the repository:

```python
from huggingface_hub import snapshot_download

# Download the repository
repo_path = snapshot_download("RobbiePasquale/lightbulb")

print(f"Repository downloaded to: {repo_path}")
```

## Main Features

LightBulb provides seven primary functionalities, each accessible via the `main_menu.py` script using command-line arguments.

### 1. Train a Web Search Agent

**Description:**
Trains an autonomous web search agent that navigates the web, gathers relevant content, and learns to summarize and generate responses based on user queries.

**Usage:**
```bash
python main_menu.py --task train_agent
```

**Key Components:**
- **Hierarchical Reinforcement Learning (HRL):** Manages high-level (Manager) and low-level (Worker) decision-making.
- **Monte Carlo Tree Search (MCTS):** Guides the agent through complex decision trees.
- **RAGSummarizer:** Summarizes retrieved web content.
- **Knowledge Base:** Stores and retrieves information to inform future queries.

### 2. Use a Web Search Agent (Inference)

**Description:**
Utilizes the trained web search agent to process queries, perform web searches, and generate summarized responses.

**Usage:**
```bash
python main_menu.py --task test_agent
```

**Options:**
- **Interactive Mode:**
  ```bash
  python main_menu.py --task test_agent
  ```
- **Single Query Mode:**
  ```bash
  python main_menu.py --task test_agent --query "Your query here"
  ```
### 3. Train a Language Model

**Description:**
Trains a Language Model (LLM) and World Model using datasets from Hugging Face, enabling the model to handle complex reasoning and long sequences.

**Usage:**
```bash
python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256
```

**Key Arguments:**
- `--model_name`: Pretrained model (e.g., `gpt2`, `bert`).
- `--dataset_name`: Dataset from Hugging Face (e.g., `wikitext`).
- `--num_epochs`: Number of training epochs.
- `--batch_size`: Number of samples per batch.
- `--max_length`: Maximum sequence length.
### 4. Inference Using Language Model with Multi-Token Prediction, Beam Search, and MCTS

**Description:**
Generates responses using the trained language model, leveraging multi-token prediction, beam search, and MCTS for enhanced coherence and strategic reasoning.

**Usage:**
```bash
python main_menu.py --task inference_llm --query "Your query here"
```

**Process:**
1. **Multi-Token Prediction:** Predicts multiple tokens at each step to improve generation speed.
2. **Beam Search:** Maintains multiple candidate sequences to ensure diverse and high-quality outputs.
3. **MCTS Integration:** Uses MCTS to evaluate and select the most promising token sequences based on policy and value estimates.
### 5. Train a Language World Model

**Description:**
Develops a comprehensive World Model that encapsulates state representations, dynamics, and prediction networks to simulate and predict state transitions within the Tree of Thought framework.

**Usage:**
```bash
python main_menu.py --task train_world_model --additional_args
```

**Key Components:**
- **Representation Network:** Encodes Transformer outputs into state representations.
- **Dynamics Network:** Predicts next states based on current states and actions.
- **Prediction Network:** Generates policy logits and value estimates.
- **Action Encoder:** Encodes actions into embeddings for state transitions.
### 6. Inference with Language World Model

**Description:**
Utilizes the trained World Model to perform advanced reasoning and generate responses based on structured thought processes and state simulations.

**Usage:**
```bash
python main_menu.py --task inference_world_model --query "Your query here"
```

**Features:**
- **Tree of Thought (ToT):** Structures reasoning paths hierarchically.
- **Beam Search with MCTS:** Enhances decision-making by balancing exploration and exploitation.
- **Integration with Knowledge Base:** Leverages stored information for informed responses.
### 7. Inference with World Model, Tree of Thought, and Multi-Token Beam Search
|
| 142 |
|
| 143 |
+
**Description:**
|
| 144 |
+
Executes inference using the World Model integrated with ToT and multi-token beam search for highly coherent and contextually rich outputs.
|
| 145 |
|
| 146 |
+
**Usage:**
|
| 147 |
+
```bash
|
| 148 |
+
python main_menu.py --task advanced_inference --query "Your complex query here"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
```

**Process:**

1. **State Initialization:** Converts the input query into a state representation.
2. **MCTS with Beam Search:** Explores multiple reasoning paths simultaneously.
3. **Thought Sequence Generation:** Produces a sequence of interconnected thoughts/actions.
4. **Final Response Generation:** Synthesizes the best thought path into a coherent response.
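The beam-search part of this process can be sketched in plain Python. This toy version substitutes a deterministic candidate generator for the real MCTS rollouts; `expand` and its scores are hypothetical stand-ins.

```python
# Toy beam search over "thought" sequences (steps 2-4 above, simplified).
import heapq

def expand(thought_seq):
    # Hypothetical candidate generator: each sequence branches into scored follow-ups.
    return [(thought_seq + [f"t{len(thought_seq)}-{i}"], 1.0 / (i + 1)) for i in range(3)]

def beam_search(initial, beam_width=2, depth=3):
    beams = [(0.0, [initial])]               # (cumulative score, thought sequence)
    for _ in range(depth):
        candidates = []
        for score, seq in beams:
            for new_seq, step_score in expand(seq):
                candidates.append((score + step_score, new_seq))
        # Keep only the top `beam_width` sequences.
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
    # Return the best thought path for final response generation.
    return max(beams, key=lambda c: c[0])[1]

best_path = beam_search("query-state")
print(best_path)  # → ['query-state', 't1-0', 't2-0', 't3-0']
```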

## General Arguments

| Argument           | Required | Description                                                                                    | Default             |
|--------------------|----------|------------------------------------------------------------------------------------------------|---------------------|
| `--task`           | Yes      | Specifies the task to run (`train_llm_world`, `train_agent`, `test_agent`, etc.).               | None                |
| `--model_name`     | No       | Pretrained model name for the LLM (`gpt2`, `bert`, etc.) or a custom model path.                | `gpt2`              |
| `--dataset_name`   | No       | Hugging Face dataset used to train the LLM and World Model (e.g., `wikitext`).                  | `wikitext`          |
| `--dataset_config` | No       | Configuration name for the dataset.                                                             | `wikitext-2-raw-v1` |
| `--batch_size`     | No       | Number of samples per batch during training.                                                    | `4`                 |
| `--num_epochs`     | No       | Number of training epochs.                                                                      | `3`                 |
| `--max_length`     | No       | Maximum sequence length for training/inference.                                                 | `128`               |
| `--mode`           | No       | Mode for the LLM and World Model (`train` or `inference`).                                      | `train`             |
| `--query`          | No       | Query input for `test_agent` when running a single query.                                       | `''` (empty)        |
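The table above corresponds to an `argparse` setup roughly like the following; the actual parser in `main_menu.py` may differ (in particular, the full set of `--task` choices is assumed here).

```python
# Minimal argparse sketch matching the argument table; illustrative only.
import argparse

def build_parser():
    p = argparse.ArgumentParser(description="LightBulb main menu")
    p.add_argument("--task", required=True,
                   choices=["train_llm_world", "train_agent", "test_agent",
                            "inference_world_model", "advanced_inference"])
    p.add_argument("--model_name", default="gpt2")
    p.add_argument("--dataset_name", default="wikitext")
    p.add_argument("--dataset_config", default="wikitext-2-raw-v1")
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--num_epochs", type=int, default=3)
    p.add_argument("--max_length", type=int, default=128)
    p.add_argument("--mode", choices=["train", "inference"], default="train")
    p.add_argument("--query", default="")
    return p

args = build_parser().parse_args(["--task", "test_agent", "--query", "hello"])
print(args.task, args.model_name, args.batch_size)
```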

## Requirements

- **Python:** 3.7+
- **Libraries:**
  - `torch>=1.7.1`
  - `transformers`
  - `datasets`
  - `argparse` (standard library)
  - `huggingface_hub`

## Usage Examples

### Training the Language Model and World Model

```bash
python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256
```

### Training the Web Search Agent

```bash
python main_menu.py --task train_agent
```

### Testing the Web Search Agent in Interactive Mode

```bash
python main_menu.py --task test_agent
```

### Testing the Web Search Agent with a Single Query

```bash
python main_menu.py --task test_agent --query "What are the impacts of renewable energy on global sustainability?"
```

### Advanced Inference with World Model and Tree of Thought

```bash
python main_menu.py --task advanced_inference --query "Analyze the economic effects of artificial intelligence in the next decade."
```

## Citation

If you use LightBulb in your research, please cite the author:

```bibtex
@misc{RobbiePasquale_lightbulb,
  author       = {Robbie Pasquale},
  title        = {LightBulb: An Autonomous Web Search and Language Model Framework},
  year         = {2024},
  publisher    = {Huggingface},
  howpublished = {\url{https://huggingface.co/RobbiePasquale/lightbulb}},
}
```

## License

This project is licensed under the Apache 2.0 License.

---

For more detailed information on each component and advanced configurations, please refer to the [documentation](https://huggingface.co/RobbiePasquale/lightbulb).