| |
|
| | import math
|
| | import random
|
| | from nltk.corpus import wordnet
|
| | from scrapy.crawler import CrawlerRunner
|
| | from scrapy.utils.log import configure_logging
|
| | from scrapy.utils.project import get_project_settings
|
| | from twisted.internet import reactor, defer
|
| | from scrapy import signals
|
| | import logging
|
| | from my_search_engine.my_search_engine.spiders.search_spider import SearchSpider
|
| | from sentence_transformers import SentenceTransformer, util
|
| | from ranking import train_ranking_model
|
| | import time
|
| |
|
| | logger = logging.getLogger(__name__)
|
| |
|
class MCTSNode:
    """A single node in the Monte Carlo tree.

    Holds a query ``state``, a link to its parent, the ``action`` that
    produced it, accumulated ``value`` over ``visits`` rollouts, and a
    cached UCB score (infinite until the node has been visited).
    """

    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0
        self.ucb_score = float('inf')

    def is_leaf(self):
        """Return True when this node has no children."""
        return not self.children

    def add_child(self, child_state, action=None):
        """Attach and return a new child node for *child_state*."""
        child = MCTSNode(child_state, parent=self, action=action)
        self.children.append(child)
        return child

    def update(self, reward):
        """Record one visit with *reward* and refresh the cached UCB score."""
        self.visits += 1
        self.value += reward
        if self.parent is not None:
            self.ucb_score = self.calculate_ucb()

    def calculate_ucb(self, exploration_weight=1.41):
        """UCB1 score: mean value plus an exploration bonus.

        Unvisited nodes and the root (no parent) score infinity so they
        are always explored first.
        """
        if self.visits == 0 or self.parent is None:
            return float('inf')
        exploitation_term = self.value / self.visits
        exploration_term = exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )
        return exploitation_term + exploration_term
|
| |
|
class MCTS:
    """Monte Carlo Tree Search over search-query refinements.

    Each node's state is a query string. Expansion generates refined
    query variants (WordNet synonyms + intent keywords), simulation runs
    a Scrapy crawl for the query and ranks the results, and the top
    ranking score — blended with query-shape heuristics — becomes the
    reward that is backpropagated up the tree.
    """

    def __init__(self, initial_state, num_simulations=20, exploration_weight=1.41):
        self.root = MCTSNode(initial_state)
        self.num_simulations = num_simulations
        self.exploration_weight = exploration_weight
        # Sentence embedder used for semantic-similarity heuristics.
        self.query_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.results = []
        self.crawler_runner = CrawlerRunner(get_project_settings())
        self.initial_state = initial_state
        self.num_iterations = 5

    def select(self, node):
        """Descend from *node* to a leaf, following the max-UCB child at
        each level.

        NOTE: the original inner ``if not node.children`` branch was dead
        code — ``is_leaf()`` being false guarantees children exist — and
        has been removed; behavior is unchanged.
        """
        while not node.is_leaf():
            node = max(node.children, key=lambda c: c.calculate_ucb(self.exploration_weight))
        return node

    def expand(self, node):
        """Add refinement children to *node* and return one at random.

        An unvisited node is returned as-is (it will be simulated first);
        if no refinements were produced, the node itself is returned.
        """
        if node.visits == 0:
            return node
        for refinement in self.get_possible_refinements(node.state):
            node.add_child(refinement)
        return random.choice(node.children) if node.children else node

    def calculate_combined_reward(self, ranking_score, state):
        """Blend the ranking score with query-shape heuristics.

        Weights: 0.5 ranking score, 0.2 query length, 0.2 lexical
        diversity (unique/total tokens), 0.1 semantic similarity to the
        root query.
        """
        state_length_reward = len(state) / 100
        tokens = state.split()
        # FIX: the original guarded on ``if state:``, which passes for a
        # whitespace-only state and then divided by len([]) == 0.
        if tokens:
            query_complexity = len(set(tokens)) / len(tokens)
        else:
            query_complexity = 0
        semantic_similarity = self.calculate_semantic_similarity(state, self.root.state)

        combined_reward = (
            0.5 * ranking_score +
            0.2 * state_length_reward +
            0.2 * query_complexity +
            0.1 * semantic_similarity
        )
        return combined_reward

    def calculate_semantic_similarity(self, query1, query2):
        """Cosine similarity between the two queries' sentence embeddings."""
        embedding1 = self.query_model.encode(query1)
        embedding2 = self.query_model.encode(query2)
        return util.pytorch_cos_sim(embedding1, embedding2).item()

    def backpropagate(self, node, reward):
        """Propagate *reward* from *node* up to the root."""
        while node is not None:
            node.update(reward)
            node = node.parent

    def best_action(self):
        """Return the most-valuable root child (mean reward), or the root
        itself when it has no children."""
        if not self.root.children:
            return self.root

        def score(node):
            # Unvisited children have no meaningful mean; rank them last.
            if node.visits == 0:
                return float('-inf')
            return node.value / node.visits

        return max(self.root.children, key=score)

    def refine_query(self, query):
        """Rewrite *query* by swapping non-stopwords for a random WordNet
        synonym and appending a random intent keyword."""
        words = query.split()
        refined_query = []

        for word in words:
            if word.lower() not in {"how", "to", "get", "an", "the", "and", "or", "of", "build"}:
                synonyms = wordnet.synsets(word)
                if synonyms:
                    # FIX: WordNet joins multi-word lemmas with underscores
                    # (e.g. "language_model"), never spaces, so the original
                    # ``len(lemma.name().split()) == 1`` filter never fired.
                    synonym_words = [lemma.name() for lemma in synonyms[0].lemmas()
                                     if '_' not in lemma.name() and word != lemma.name()]
                    if synonym_words:
                        refined_query.append(random.choice(synonym_words))
                    else:
                        refined_query.append(word)
                else:
                    refined_query.append(word)
            else:
                refined_query.append(word)

        possible_intent_keywords = ['guide', 'tutorial', 'LLM', 'language model', 'NLP', 'GPT']
        refined_query.append(random.choice(possible_intent_keywords))

        return ' '.join(refined_query)

    def get_related_queries(self, query):
        """Generate synonym-substituted variants of *query* and keep the
        ones embedding-similar to the original (cosine > 0.8), capped at 2."""
        query_embedding = self.query_model.encode(query)
        refined_query_variations = [query]
        words_to_avoid = {'how', 'to', 'get'}
        words = query.split()

        for word in words:
            if word.lower() not in words_to_avoid:
                synonyms = wordnet.synsets(word)
                if synonyms:
                    synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word]
                    if synonym_words:
                        # FIX: normalize WordNet's underscore-joined
                        # multi-word lemmas so queries stay human-readable
                        # (e.g. "sense_datum" -> "sense datum").
                        replacement = random.choice(synonym_words).replace('_', ' ')
                        refined_query = query.replace(word, replacement)
                        refined_query_variations.append(refined_query)

        refined_query_variations = list(set(refined_query_variations))
        refined_query_embeddings = [self.query_model.encode(variation) for variation in refined_query_variations]
        similarity_scores = util.pytorch_cos_sim(query_embedding, refined_query_embeddings).tolist()[0]

        similarity_threshold = 0.8
        filtered_queries = [variation for idx, variation in enumerate(refined_query_variations)
                            if similarity_scores[idx] > similarity_threshold]

        return filtered_queries[:2] if filtered_queries else [query]

    def get_possible_refinements(self, query):
        """All candidate child states: related queries plus one random
        synonym/intent rewrite."""
        refined_queries = self.get_related_queries(query)
        return refined_queries + [self.refine_query(query)]

    @defer.inlineCallbacks
    def web_search(self, query, search_sites=None):
        """Crawl the web for *query* via Scrapy and collect scraped items.

        Returns (via the Deferred) the list of scraped items; an empty
        list on empty query or crawl failure.
        """
        if not query.strip():
            logger.error("Cannot perform web search with an empty query.")
            defer.returnValue([])

        logger.info(f"Starting web search for query: {query}")
        configure_logging(install_root_handler=False)
        logging.basicConfig(level=logging.INFO)

        results = []

        def crawler_results(item, response, spider):
            # Signal handler: accumulate each scraped item as it arrives.
            logger.info(f"Received result: {item['title']}")
            results.append(item)

        try:
            crawler = self.crawler_runner.create_crawler(SearchSpider)
            crawler.signals.connect(crawler_results, signal=signals.item_scraped)

            yield self.crawler_runner.crawl(crawler, query=query, search_sites=search_sites)
        except Exception as e:
            logger.error(f"Error during web search: {str(e)}")
            defer.returnValue([])

        logger.info(f"Web search completed. Found {len(results)} results.")
        defer.returnValue(results)

    @defer.inlineCallbacks
    def run(self):
        """Run the MCTS loop and return (via the Deferred) the best query
        found, falling back to the root state."""
        logger.info(f"Starting MCTS run with {self.num_iterations} iterations")
        for i in range(self.num_iterations):
            logger.debug(f"Iteration {i+1}/{self.num_iterations}")
            leaf = self.select(self.root)
            child = self.expand(leaf)
            reward = yield self.simulate(child)
            self.backpropagate(child, reward)

        best_child = self.best_action()
        logger.info(f"MCTS run completed. Best action: {best_child.state}")
        defer.returnValue(best_child.state if best_child != self.root else self.root.state)

    @defer.inlineCallbacks
    def simulate(self, node):
        """Rollout: search the web for the node's query, rank the results,
        and convert the top predicted score into a combined reward."""
        query_results = yield self.web_search(node.state)
        ranked_results = train_ranking_model(node.state, query_results)

        if ranked_results:
            top_score = ranked_results[0]['predicted_score']
        else:
            # No results ranked — reward rests on the query heuristics alone.
            top_score = 0

        reward = self.calculate_combined_reward(top_score, node.state)
        defer.returnValue(reward)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|