| |
|
| | import random
|
| | from typing import List, Dict, Any, Generator
|
| | from sentence_transformers import SentenceTransformer, util
|
| | import torch
|
| | import numpy as np
|
| | from twisted.internet import defer
|
| | from agent import AutonomousWebAgent
|
| | from mcts import MCTS, MCTSNode
|
| | import logging
|
| | from twisted.internet.defer import Deferred
|
| |
|
| |
|
# Module-level logger; ToTSearch emits its debug/info/warning records through it.
logger = logging.getLogger(name=__name__)
|
| |
|
class ToTNode:
    """One thought in a Tree-of-Thoughts search tree.

    All attributes are plain, mutable public fields; the search code reads
    and writes them directly.
    """

    def __init__(self, thought, parent=None):
        # Tree links.
        self.parent = parent
        self.children = []
        # The thought text carried by this node.
        self.thought = thought
        # Number of updates received and the running *sum* of their rewards
        # (not an average; divide by ``visits`` for the mean).
        self.visits = 0
        self.value = 0
        # Web search results attached to this node during simulation.
        self.search_results = []
        # MCTS node associated with this thought, once one has been created.
        self.mcts_node = None

    def add_child(self, child_thought):
        """Attach a new child holding *child_thought* and return it."""
        new_node = ToTNode(child_thought, parent=self)
        self.children.append(new_node)
        return new_node

    def update(self, reward):
        """Record one more visit and add *reward* to the accumulated value."""
        self.visits, self.value = self.visits + 1, self.value + reward
|
| |
|
class ToTSearch:
    """Tree-of-Thoughts search over LLM-generated thoughts.

    Each round selects a node, expands it with LLM sub-thoughts, scores one
    of them (web retrieval + agent reward + MCTS rollouts) and backpropagates
    the score. The finished tree is then synthesized into a single answer.
    """

    def __init__(self, agent: AutonomousWebAgent, model='all-MiniLM-L6-v2', max_depth=3,
                 num_thoughts=3, num_simulations=100):
        """
        Args:
            agent: Agent used for text generation, web retrieval, rewards,
                RL memory/replay and knowledge-base storage.
            model: SentenceTransformer model name used for relevance scoring.
            max_depth: Maximum number of levels descended by ``select`` and by
                the random walk in ``simulate``.
            num_thoughts: Number of (sub-)thoughts requested per LLM prompt.
            num_simulations: Number of rounds in ``tot_search`` and of MCTS
                rollouts per simulated thought.
        """
        self.agent = agent
        self.model = SentenceTransformer(model)
        self.max_depth = max_depth
        self.num_thoughts = num_thoughts
        self.num_simulations = num_simulations
        self.mcts = MCTS(initial_state="", num_simulations=num_simulations)

    @staticmethod
    def _split_lines(text: str) -> List[str]:
        """Split LLM output into stripped, non-empty lines."""
        return [line.strip() for line in text.split('\n') if line.strip()]

    def generate_thoughts(self, query: str) -> List[str]:
        """Ask the LLM for ``num_thoughts`` distinct approaches to *query*.

        Returns one entry per non-empty response line; the count is not
        enforced, so it may differ from ``num_thoughts``.
        """
        prompt = (
            f'Given the query "{query}", generate {self.num_thoughts} distinct thoughts '
            f'or approaches to address it.\n'
            "Each thought should be a complete sentence and offer a unique perspective "
            "or solution path."
        )
        return self._split_lines(self.agent.generate_text(prompt))

    def expand_thought(self, thought: str) -> List[str]:
        """Ask the LLM for ``num_thoughts`` more specific sub-thoughts of *thought*.

        Returns one entry per non-empty response line (count not enforced).
        """
        prompt = (
            f'Expand on the following thought: "{thought}"\n'
            f"Generate {self.num_thoughts} more specific sub-thoughts or considerations.\n"
            "Each sub-thought should be a complete sentence and offer additional detail "
            "or a new angle."
        )
        return self._split_lines(self.agent.generate_text(prompt))

    def evaluate_thought(self, thought: str, query: str) -> float:
        """Return the cosine similarity of the embeddings of *thought* and *query*."""
        thought_embedding = self.model.encode(thought)
        query_embedding = self.model.encode(query)
        return util.pytorch_cos_sim(thought_embedding, query_embedding).item()

    @defer.inlineCallbacks
    def search_and_augment(self, thought: str) -> Generator[Deferred, Any, List[Dict[str, Any]]]:
        """Retrieve web results for *thought* and tag each with its origin.

        Each result dict is mutated in place to gain an ``'originating_thought'`` key.
        """
        search_results = yield self.agent.retrieve_from_web(thought)
        for result in search_results:
            result['originating_thought'] = thought
        return search_results

    def select(self, node: ToTNode) -> ToTNode:
        """Descend from *node* to the node that should be expanded next.

        At each level a randomly chosen unvisited child is returned at once.
        Otherwise the child with the highest average reward is followed
        (greedy, with no exploration bonus). Descent stops at a leaf or after
        ``max_depth`` levels, so selection never goes deeper than that.
        """
        depth = 0
        while node.children and depth < self.max_depth:
            unvisited = [child for child in node.children if child.visits == 0]
            if unvisited:
                selected_node = random.choice(unvisited)
                logger.debug("Selected node with 0 visits: %s", selected_node.thought)
                return selected_node
            # Every child has visits > 0 here, so the division is safe.
            node = max(node.children, key=lambda child: child.value / child.visits)
            logger.debug("Selected node based on value/visits ratio: %s, value: %s, visits: %s",
                         node.thought, node.value, node.visits)
            depth += 1
        return node

    def expand(self, node: ToTNode, query: str) -> ToTNode:
        """Expand a leaf with LLM sub-thoughts and return one child at random.

        Only leaves whose thought has more than two words are expanded. If
        *node* still has no children, it is returned unchanged. ``query`` is
        currently unused.
        """
        if not node.children and len(node.thought.split()) > 2:
            for expansion in self.expand_thought(node.thought):
                node.add_child(expansion)
        return random.choice(node.children) if node.children else node

    @defer.inlineCallbacks
    def simulate(self, node: ToTNode, query: str) -> Generator[Deferred, Any, float]:
        """Score *node* (or a random descendant) and return the combined reward.

        The method walks at most ``max_depth`` random steps down from *node*
        and attaches web results to the thought it reaches. It then returns
        the mean of the agent's reward and the average reward of
        ``num_simulations`` MCTS rollouts.
        """
        current_node = node
        for _ in range(self.max_depth):
            if not current_node.children:
                break
            current_node = random.choice(current_node.children)

        logger.debug("Simulating for thought: %s", current_node.thought)

        search_results = yield self.search_and_augment(current_node.thought)
        current_node.search_results = search_results
        logger.debug("Search results count: %s", len(search_results))

        # NOTE(review): presumably a numeric score from the agent; it is
        # averaged with the MCTS value below.
        agent_reward = self.agent.calculate_reward(current_node.thought, query)
        logger.debug("Ranked results: %s", agent_reward)

        mcts_node = MCTSNode(current_node.thought)
        current_node.mcts_node = mcts_node
        mcts_total_reward = 0
        for _ in range(self.num_simulations):
            mcts_reward = yield self.mcts.simulate(mcts_node)
            mcts_total_reward += mcts_reward
            self.mcts.backpropagate(mcts_node, mcts_reward)

        logger.debug("MCTS node visits: %s, total reward: %s", mcts_node.visits, mcts_total_reward)

        # The average must be computed before it is combined. Guard against
        # zero visits (assumes MCTS.backpropagate increments visits; confirm).
        if mcts_node.visits > 0:
            mcts_value = mcts_total_reward / mcts_node.visits
            logger.debug("MCTS value: %s", mcts_value)
        else:
            mcts_value = 0
            logger.warning("MCTS node has 0 visits, assigning value 0")

        combined_reward = (agent_reward + mcts_value) / 2
        logger.debug("Combined reward: %s", combined_reward)
        return combined_reward

    def backpropagate(self, node: ToTNode, reward: float):
        """Add *reward* and one visit to *node* and each of its ancestors."""
        while node is not None:
            node.update(reward)
            node = node.parent

    @defer.inlineCallbacks
    def tot_search(self, query: str) -> Generator[Deferred, Any, ToTNode]:
        """Run ``num_simulations`` select/expand/simulate/backpropagate rounds.

        Each round is also stored in the agent's worker memory. After all
        rounds the agent replays its worker and manager memories. Returns the
        root of the search tree (its thought is *query*).
        """
        root = ToTNode(query)
        for _ in range(self.num_simulations):
            node = self.select(root)
            node = self.expand(node, query)
            reward = yield self.simulate(node, query)
            self.backpropagate(node, reward)

            # Record this round as an RL transition. The argument order is
            # presumably (state, action, reward, next_state, done); confirm
            # against the agent.
            state = self.agent.extract_features(node.thought, query)
            next_thought = node.children[0].thought if node.children else node.thought
            next_state = self.agent.extract_features(next_thought, query)
            self.agent.remember_worker(state, 0, reward, next_state, False)

        self.agent.replay_worker()
        self.agent.replay_manager()
        return root

    def get_best_path(self, root: ToTNode) -> List[str]:
        """Return the thoughts from *root* down the highest average-reward children.

        Unvisited children rank lowest (``-inf``) but are still followed when
        no visited sibling exists, so the path always ends at a leaf.
        """
        path = [root.thought]
        current = root
        while current.children:
            current = max(current.children,
                          key=lambda child: child.value / child.visits if child.visits > 0 else float('-inf'))
            path.append(current.thought)
        return path

    def _rank_by_relevance(self, results: List[Dict[str, Any]], query: str) -> List[Dict[str, Any]]:
        """Return *results* ordered by cosine similarity of their 'content' to *query*.

        The best match comes first. The query is embedded once and all contents
        are embedded in a single batch. The sort is stable, so equal scores keep
        their input order.
        """
        if not results:
            return []
        query_embedding = self.model.encode(query)
        content_embeddings = self.model.encode([result['content'] for result in results])
        # Similarity matrix has shape (len(results), 1); take the single column.
        scores = util.pytorch_cos_sim(content_embeddings, query_embedding)[:, 0].tolist()
        order = sorted(range(len(results)), key=scores.__getitem__, reverse=True)
        return [results[i] for i in order]

    @defer.inlineCallbacks
    def synthesize_results(self, root: ToTNode, query: str, top_k: int = 5) -> Generator[Deferred, Any, str]:
        """Turn the search tree into a final answer and store it in the knowledge base.

        The method gathers every node's search results (preorder), keeps the
        *top_k* most relevant to *query*, and asks the agent for a RAG
        response over them. The answer is saved with the best thought path as
        metadata.
        """
        best_path = self.get_best_path(root)

        # Iterative preorder walk (same order as a recursive DFS).
        all_results = []
        stack = [root]
        while stack:
            node = stack.pop()
            all_results.extend(node.search_results)
            stack.extend(reversed(node.children))

        top_results = self._rank_by_relevance(all_results, query)[:top_k]

        final_answer = yield self.agent.generate_rag_response(query, top_results)

        self.agent.add_document_to_kb(
            title=f"ToT Search Result: {query}",
            content=final_answer,
            metadata={"thought_path": best_path},
        )
        return final_answer

    @defer.inlineCallbacks
    def search(self, query: str) -> Generator[Deferred, Any, str]:
        """Run the full ToT pipeline for *query* and return the synthesized answer."""
        logger.info("Starting ToT search for query: %s", query)
        root = yield self.tot_search(query)
        final_answer = yield self.synthesize_results(root, query)
        logger.info("ToT search completed for query: %s", query)
        return final_answer
|
| |
|
| |
|
| |
|
| | |