RobbiePasquale committed on
Commit 2991289 · verified · 1 Parent(s): b29ed0e

Update README.md

Files changed (1)
  1. README.md +59 -141

README.md CHANGED
@@ -283,7 +283,6 @@ python main_menu.py --task advanced_inference --query "Your complex query here"
 3. **Thought Sequence Generation:** Produces a sequence of interconnected thoughts/actions.
 4. **Final Response Generation:** Synthesizes the best thought path into a coherent response.
 
-
 ---
 
 ### **Mode: With World Model and Tree of Thought**
@@ -291,162 +290,97 @@ python main_menu.py --task advanced_inference --query "Your complex query here"
 #### **Step 1: Input Tokenization and Encoding**
 
 1. **Tokenization:**
-   \[
-   \text{input\_ids} = \text{tokenizer.encode}(\text{query}, \text{return\_tensors='pt'})
-   \]
-   - **Shape:** \((\text{batch\_size}=1, \text{seq\_len})\)
 
 2. **Encoding via Transformer:**
-   \[
-   \text{transformer\_output} = \text{model\_transformer}(\text{input\_ids}, \text{input\_ids})
-   \]
-   - **Shape:** \((\text{batch\_size}=1, \text{seq\_len}, \text{d\_model})\)
 
 3. **State Representation:**
-   \[
-   \text{initial\_representation} = \text{RepresentationNetwork}(\text{transformer\_output})[:, -1, :].\text{unsqueeze}(1)
-   \]
-   - **Shape:** \((\text{batch\_size}=1, 1, \text{state\_dim})\)
 
 4. **State Initialization:**
-   \[
-   \text{initial\_state} = \text{State}(\text{representation}=\text{initial\_representation}, \text{dynamics\_network}=\text{dynamics\_network}, \text{action\_encoder}=\text{action\_encoder}, \text{thought\_node}=\text{root\_thought\_node})
-   \]
 
 #### **Step 2: MCTS Initialization and Root Node Evaluation**
 
 1. **MCTS Instance Creation:**
-   \[
-   \text{mcts} = \text{MCTS}(\text{prediction\_network}, \text{dynamics\_network}, \text{action\_encoder}, \text{num\_iterations}=\text{mcts\_iterations}, \text{exploration\_constant}=\text{exploration\_constant})
-   \]
 
 2. **Root Node Creation:**
-   \[
-   \text{root\_node} = \text{MCTSNode}(\text{state}=\text{initial\_state}, \text{thought\_node}=\text{root\_thought\_node})
-   \]
 
 3. **Root Node Evaluation:**
-   \[
-   \text{value\_estimate} = \text{mcts.evaluate}(\text{root\_node})
-   \]
 
 4. **Backpropagation:**
-   \[
-   \text{mcts.backpropagate}(\text{root\_node}, \text{value\_estimate})
-   \]
 
 #### **Step 3: MCTS Iterations with Beam Search**
 
 1. **Beam Initialization:**
-   \[
-   \text{beam} = \left\{ \left( \text{root\_node}, \text{score} = 0, \text{cum\_entropy} = 0, \text{cum\_variance} = 0, \text{action\_sequence} = [] \right) \right\}
-   \]
 
 2. **Iterative Expansion:**
-   - **For each iteration up to \(\text{num\_iterations}\):**
 
    - **Candidate Collection:**
      - **For each node in the current beam:**
 
        - **Leaf Evaluation:**
-         - **If** \(\text{node.is\_leaf()} = \text{True}\):
-           \[
-           \text{value\_estimate} = \text{mcts.evaluate}(\text{node})
-           \]
-           \[
-           \text{mcts.backpropagate}(\text{node}, \text{value\_estimate})
-           \]
 
        - **Child Selection:**
-         - **If** \(\text{node.children}\) **is not empty:**
-           - **Calculate total visits:**
-             \[
-             \text{total\_visits} = \sum_{\text{child} \in \text{node.children}} \text{child.visit\_count}
-             \]
-           - **Select top \(\text{beam\_size}\) actions based on UCB scores:**
-             \[
-             \text{sorted\_children} = \text{sorted}(\text{node.children.items()}, \text{key} = \lambda\,\text{item}: \text{item}[1].\text{ucb\_score}(\text{total\_visits}, \text{exploration\_constant}), \text{reverse} = \text{True})[:\text{beam\_size}]
-             \]
 
-       - **Action Sequence Prediction:**
-         - **For each selected action:**
 
-           - **Initialize:**
-             - \(\text{current\_node} = \text{selected\_node}\)
-             - \(\text{current\_sequence} = [\text{selected\_action}]\)
-             - \(\text{current\_score} = 0\)
-             - \(\text{current\_entropy} = 0\)
-             - \(\text{current\_variance} = 0\)
 
-           - **Multi-Token Prediction:**
-             - **For each step in \(\text{n\_tokens\_predict}\):**
-
-               - **If** \(\text{current\_node.is\_leaf()} = \text{True}\):
-                 \[
-                 \text{value\_estimate} = \text{mcts.evaluate}(\text{current\_node})
-                 \]
-                 \[
-                 \text{mcts.backpropagate}(\text{current\_node}, \text{value\_estimate})
-                 \]
-
-               - **If** \(\text{current\_node.children}\) **is empty**, **break**.
-
-               - **Calculate total visits for the new node:**
-                 \[
-                 \text{total\_visits} = \sum_{\text{child} \in \text{current\_node.children}} \text{child.visit\_count}
-                 \]
-
-               - **Select the action with the highest UCB score:**
-                 \[
-                 (\text{next\_action}, \text{next\_node}) = \max(\text{current\_node.children.items()}, \text{key} = \lambda\,\text{item}: \text{item}[1].\text{ucb\_score}(\text{total\_visits}, \text{exploration\_constant}))
-                 \]
-
-               - **Score Update:**
-                 \[
-                 \text{current\_score} \mathrel{+}= \begin{cases} \dfrac{\text{next\_node.value\_sum}}{\text{next\_node.visit\_count}} & \text{if } \text{next\_node.visit\_count} > 0 \\ 0 & \text{otherwise} \end{cases}
-                 \]
-
-               - **Entropy and Variance Update:**
-                 \[
-                 \text{current\_entropy} \mathrel{+}= \text{next\_node.entropy}
-                 \]
-                 \[
-                 \text{current\_variance} \mathrel{+}= \text{next\_node.variance}
-                 \]
-
-               - **Append next action to the sequence:**
-                 \[
-                 \text{current\_sequence}.\text{append}(\text{next\_action})
-                 \]
-
-               - **Update current node:**
-                 \[
-                 \text{current\_node} = \text{next\_node}
-                 \]
 
-         - **Candidate Aggregation:**
-           \[
-           \text{all\_candidates.append}((\text{current\_node}, \text{current\_score}, \text{current\_entropy}, \text{current\_variance}, \text{current\_sequence}))
-           \]
 
       - **Beam Pruning:**
-        \[
-        \text{beam} = \text{sorted}(\text{all\_candidates}, \text{key} = \lambda\,x: x[1] - 0.1 \times x[2] + 0.05 \times x[3], \text{reverse} = \text{True})[:\text{beam\_size}]
-        \]
 
 3. **Termination:**
-   - **Stop early** if no candidates remain or all beams have reached terminal nodes.
-
 4. **Result Extraction:**
-   \[
-   \text{best\_sequence} = \text{beam}[0][4]
-   \]
-   - **Return:** \(\text{best\_sequence}\) as the generated sequence of actions (thoughts).
 
 ---
 
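Read as a pipeline, the mode above reduces to four calls. The sketch below is a toy, runnable outline of that control flow only — the whitespace tokenizer, identity-style encoder, and stub MCTS are stand-ins for the repository's real `tokenizer`, `model_transformer`, `RepresentationNetwork`, and `MCTS` classes:

```python
# Toy pipeline sketch: only the control flow mirrors the README; every
# component below is a placeholder, not the repository's implementation.

def tokenize(query):
    # Stand-in for tokenizer.encode(query, return_tensors='pt').
    return [query.split()]                              # (batch_size=1, seq_len)

def encode(input_ids):
    # Stand-in for the Transformer: one "feature" per token position.
    return [[[float(len(tok))] for tok in batch] for batch in input_ids]

def represent(transformer_output):
    # Stand-in for RepresentationNetwork: keep only the final position.
    return [batch[-1:] for batch in transformer_output]  # (1, 1, state_dim)

def run_mcts(state, num_iterations=3):
    # Stub: a real MCTS would expand, evaluate, and backpropagate here.
    return ["thought_%d" % i for i in range(num_iterations)]

def advanced_inference(query):
    input_ids = tokenize(query)          # Step 1: tokenization
    hidden = encode(input_ids)           # Step 1: encoding
    state = represent(hidden)            # Step 1: state representation
    best_sequence = run_mcts(state)      # Steps 2-3: MCTS with beam search
    return best_sequence
```

Swapping in the real components preserves this shape: only `run_mcts` grows, since it owns the beam-search loop described in Step 3.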
@@ -455,8 +389,8 @@ python main_menu.py --task advanced_inference --query "Your complex query here"
 #### **1. Step 1: Input Tokenization and Encoding**
 
 - **Tokenization:** The input query is converted into token IDs using the tokenizer. This numerical representation is essential for processing by the Transformer model.
-
-- **Encoding via Transformer:** The tokenized input is passed through the Transformer model to generate contextual embeddings (`transformer_output`). These embeddings capture the semantic information of the input.
 
 - **State Representation:** The `RepresentationNetwork` processes the transformer's output to create a condensed state representation. This state serves as the foundation for further reasoning steps.
 
@@ -468,7 +402,9 @@ python main_menu.py --task advanced_inference --query "Your complex query here"
 
 - **Root Node Creation:** A `MCTSNode` representing the root of the search tree is created, associated with the initial state and the root thought node.
 
-- **Root Node Evaluation:** The root node is evaluated using the MCTS's evaluation function, which assesses the potential value of the current state. The result is then backpropagated to update the node's statistics.
 
 #### **3. Step 3: MCTS Iterations with Beam Search**
 
@@ -492,7 +428,7 @@ python main_menu.py --task advanced_inference --query "Your complex query here"
 
 - **Candidate Aggregation:** All potential candidates resulting from the action predictions are collected for the current iteration.
 
-- **Beam Pruning:** After all candidates are collected, the beam is pruned to retain only the top sequences based on a scoring function that balances score, entropy, and variance. This ensures that only the most promising action sequences are retained for further exploration.
 
 - **Termination:** The iterative process continues until the specified number of MCTS iterations is reached or all beams have been exhausted.
 
@@ -571,24 +507,6 @@ graph TD
 R --> Q
 ```
 
----
-
-### **Summary**
-
-Your advanced inference mechanism integrates several sophisticated components to enable strategic and coherent response generation:
-
-1. **World Model:** Provides a structured understanding of the current state and predicts future states based on actions.
-2. **Tree of Thought (ToT):** Offers a hierarchical framework for exploring diverse reasoning paths.
-3. **Monte Carlo Tree Search (MCTS):** Efficiently navigates the ToT by balancing exploration and exploitation using UCB scores influenced by entropy and variance.
-4. **Beam Search with Multi-Token Prediction:** Enhances generation efficiency and coherence by predicting multiple tokens simultaneously and maintaining multiple candidate sequences.
-5. **Entropy and Variance:** Quantify uncertainty and diversity, guiding the search process to generate balanced and robust responses.
-
-This comprehensive system enables the model to handle complex queries by simulating strategic reasoning, exploring various thought pathways, and generating informed and coherent responses.
-
----
-
-Feel free to integrate this Markdown into your HuggingFace documentation or repository. If you need further customization or additional sections, let me know!
-
 ## General Arguments
 
 | Argument | Required | Description | Default |
 
3. **Thought Sequence Generation:** Produces a sequence of interconnected thoughts/actions.
4. **Final Response Generation:** Synthesizes the best thought path into a coherent response.

---

### **Mode: With World Model and Tree of Thought**

#### **Step 1: Input Tokenization and Encoding**

1. **Tokenization:**
   - The input query is converted into token IDs using the tokenizer. This numerical representation is essential for processing by the Transformer model.
   - **Shape:** The resulting tensor has a shape corresponding to the batch size and the sequence length of the input.

2. **Encoding via Transformer:**
   - The tokenized input is passed through the Transformer model to generate contextual embeddings. These embeddings capture the semantic information of the input.
   - **Shape:** The output tensor includes the batch size, sequence length, and the dimensionality of the model.

3. **State Representation:**
   - The `RepresentationNetwork` processes the transformer's output to create a condensed state representation. This state serves as the foundation for further reasoning steps.
   - **Shape:** The state representation tensor includes the batch size, a single sequence length (typically one), and the state dimensionality.

4. **State Initialization:**
   - A `State` object is created, encapsulating the initial representation, dynamics network, action encoder, and the root node of the Tree of Thought. This object maintains the current context and facilitates state transitions as actions are applied.
 
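The shapes in Step 1 can be illustrated with a toy whitespace tokenizer (a sketch only; the real pipeline uses `tokenizer.encode(query, return_tensors='pt')` and a trained Transformer plus `RepresentationNetwork`):

```python
# Toy tokenizer used only to make the (batch_size, seq_len) shape concrete.
def encode_query(query, vocab):
    ids = [vocab.setdefault(tok, len(vocab)) for tok in query.split()]
    return [ids]  # shape: (batch_size=1, seq_len)

vocab = {}
input_ids = encode_query("plan a three day trip", vocab)
assert len(input_ids) == 1      # batch_size = 1
assert len(input_ids[0]) == 5   # seq_len = 5

# After encoding, transformer_output has shape (1, seq_len, d_model);
# keeping only the final position gives the initial state of shape
# (1, 1, state_dim), as described above.
```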
#### **Step 2: MCTS Initialization and Root Node Evaluation**

1. **MCTS Instance Creation:**
   - An instance of the `MCTS` class is initialized with the necessary networks and parameters. This instance will manage the search process through the Tree of Thought.
   - Key parameters include the prediction network, dynamics network, action encoder, number of iterations, and the exploration constant.

2. **Root Node Creation:**
   - A `MCTSNode` representing the root of the search tree is created. This node is associated with the initial state and the root thought node from the Tree of Thought.

3. **Root Node Evaluation:**
   - The root node is evaluated using the MCTS's evaluation function, which assesses the potential value of the current state based on the prediction network's output.

4. **Backpropagation:**
   - The evaluation result (value estimate) is backpropagated through the tree, updating the visit counts and value sums of the nodes. This process informs future selections by providing aggregated value information.
 
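Steps 3 and 4 above can be sketched as follows. This is a minimal illustration: the real `MCTS.evaluate` queries the prediction network, so a constant value estimate stands in for it here, and the node carries only the statistics backpropagation touches:

```python
class MCTSNode:
    """Minimal node holding the statistics that backpropagation updates."""
    def __init__(self, parent=None):
        self.parent = parent
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0

def backpropagate(node, value_estimate):
    # Walk from the evaluated node back up to the root, updating the
    # visit counts and value sums that later drive UCB-based selection.
    while node is not None:
        node.visit_count += 1
        node.value_sum += value_estimate
        node = node.parent

root = MCTSNode()
value_estimate = 0.8          # stand-in for mcts.evaluate(root_node)
backpropagate(root, value_estimate)
assert root.visit_count == 1 and root.value_sum == 0.8
```

Because every ancestor is updated, a node's `value_sum / visit_count` is always the mean of all evaluations in its subtree, which is the quantity the score update in Step 3 consumes.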
#### **Step 3: MCTS Iterations with Beam Search**

1. **Beam Initialization:**
   - The search beam is initialized with the root node. Each beam element includes the current node, a cumulative score, cumulative entropy, cumulative variance, and an empty action sequence.

2. **Iterative Expansion:**
   - **For each iteration up to the specified number of MCTS iterations:**

     - **Candidate Collection:**
       - **For each node in the current beam:**

         - **Leaf Evaluation:**
           - If the node is a leaf (has no children), it is evaluated to estimate its value. The evaluation result is then backpropagated to update the node's statistics.

         - **Child Selection:**
           - If the node has children, the total number of visits to all its children is calculated. The children are then sorted by their Upper Confidence Bound (UCB) scores, which balance exploitation against exploration. The top actions (up to the beam size) are selected for expansion.

         - **Action Sequence Prediction:**
           - **For each selected action:**

             - **Initialization:**
               - Set the current node to the selected child node.
               - Initialize the current action sequence with the selected action.
               - Initialize the current score, cumulative entropy, and cumulative variance to zero.

             - **Multi-Token Prediction:**
               - **For each step in the number of tokens to predict:**

                 - **Leaf Evaluation:** If the current node is a leaf, evaluate it to obtain a value estimate and backpropagate the result.
                 - **Child Check:** If the current node has no children, exit the multi-token prediction loop for this sequence.
                 - **Action Selection:** Calculate the total number of visits to all children of the current node, then select the action with the highest UCB score.
                 - **Score Update:** If the selected child node has been visited before, increment the current score by that node's average value estimate; otherwise the score is left unchanged.
                 - **Entropy and Variance Update:** Accumulate the entropy and variance of the selected node, steering the search towards actions that are both confident and diverse.
                 - **Sequence Extension:** Append the selected action to the current action sequence and advance the current node to the selected child.

             - **Candidate Aggregation:**
               - Add the new candidate, with its updated score, entropy, variance, and action sequence, to the list of all candidates for this iteration.

     - **Beam Pruning:**
       - After collecting all candidates from the current beam, sort them with a scoring function that balances the cumulative score, entropy, and variance.
       - Retain only the top sequences, up to the specified beam size, to form the new beam for the next iteration.

3. **Termination:**
   - The iterative expansion continues until the specified number of MCTS iterations is reached or there are no more candidates to explore.

4. **Result Extraction:**
   - After the iterations complete, select the best action sequence from the final beam based on the accumulated scores and metrics.
   - Return this sequence as the generated series of actions (thoughts) in response to the input query.
 
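The inner multi-token prediction loop can be sketched as follows. This is a simplified, hypothetical rendering: `Node` is a stand-in for `MCTSNode`, and `ucb_score` here is plain UCB1, whereas the repository's version also folds entropy and variance into the score:

```python
import math
from dataclasses import dataclass, field

@dataclass
class Node:
    # Minimal stand-in for MCTSNode; field names follow the README.
    visit_count: int = 0
    value_sum: float = 0.0
    entropy: float = 0.0
    variance: float = 0.0
    children: dict = field(default_factory=dict)

def ucb_score(child, total_visits, c):
    # Plain UCB1: mean value (exploitation) plus a visit-count bonus
    # (exploration) scaled by the exploration constant c.
    if child.visit_count == 0:
        return float("inf")
    return (child.value_sum / child.visit_count
            + c * math.sqrt(math.log(total_visits + 1) / child.visit_count))

def predict_sequence(node, n_tokens_predict, c=1.4):
    """Walk up to n_tokens_predict steps along the highest-UCB children."""
    sequence, score, entropy, variance = [], 0.0, 0.0, 0.0
    for _ in range(n_tokens_predict):
        if not node.children:          # child check: nothing left to expand
            break
        total = sum(ch.visit_count for ch in node.children.values())
        action, node = max(node.children.items(),
                           key=lambda kv: ucb_score(kv[1], total, c))
        if node.visit_count > 0:       # score update with the mean value
            score += node.value_sum / node.visit_count
        entropy += node.entropy        # accumulate uncertainty metrics
        variance += node.variance
        sequence.append(action)        # sequence extension
    return sequence, score, entropy, variance
```

The returned tuple matches the candidate layout used for aggregation and pruning: an action sequence together with its cumulative score, entropy, and variance.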
---
 
#### **1. Step 1: Input Tokenization and Encoding**

- **Tokenization:** The input query is converted into token IDs using the tokenizer. This numerical representation is essential for processing by the Transformer model.

- **Encoding via Transformer:** The tokenized input is passed through the Transformer model to generate contextual embeddings. These embeddings capture the semantic information of the input.

- **State Representation:** The `RepresentationNetwork` processes the transformer's output to create a condensed state representation. This state serves as the foundation for further reasoning steps.

402
 
403
  - **Root Node Creation:** A `MCTSNode` representing the root of the search tree is created, associated with the initial state and the root thought node.
404
 
405
+ - **Root Node Evaluation:** The root node is evaluated using the MCTS's evaluation function, which assesses the potential value of the current state based on the prediction network's output.
406
+
407
+ - **Backpropagation:** The evaluation result (value estimate) is backpropagated through the tree, updating the visit counts and value sums of the nodes. This process informs future selections by providing aggregated value information.
408
 
409
  #### **3. Step 3: MCTS Iterations with Beam Search**
410
 
 
428
 
429
  - **Candidate Aggregation:** All potential candidates resulting from the action predictions are collected for the current iteration.
430
 
431
+ - **Beam Pruning:** After collecting all candidates, the beam is pruned to retain only the top sequences based on a scoring function that balances score, entropy, and variance. This ensures that only the most promising action sequences are retained for further exploration.
432
 
433
  - **Termination:** The iterative process continues until the specified number of MCTS iterations is reached or all beams have been exhausted.
434
 
 
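The pruning step corresponds directly to the scoring function given in the formula earlier, score − 0.1 × entropy + 0.05 × variance; a short sketch, with candidates laid out as the beam tuples described in Step 3:

```python
def prune_beam(all_candidates, beam_size):
    # Each candidate is (node, score, cum_entropy, cum_variance, sequence).
    # The 0.1 and 0.05 weights come from the scoring function in the README:
    # high entropy is penalized, high variance is slightly rewarded.
    ranked = sorted(all_candidates,
                    key=lambda x: x[1] - 0.1 * x[2] + 0.05 * x[3],
                    reverse=True)
    return ranked[:beam_size]

candidates = [
    (None, 1.0, 0.0, 0.0, ["a"]),   # combined score 1.0
    (None, 1.2, 3.0, 0.0, ["b"]),   # combined score 0.9
    (None, 0.5, 0.0, 2.0, ["c"]),   # combined score 0.6
]
assert [c[4] for c in prune_beam(candidates, 2)] == [["a"], ["b"]]
```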
507
  R --> Q
508
  ```
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  ## General Arguments
511
 
512
  | Argument | Required | Description | Default |