import heapq

import numpy as np

from BeamDiffusionModel.tree.node import Node


class BeamSearchTree:
    """Tree of candidate generation steps explored by a beam search.

    The tree starts with a single placeholder root node labelled
    ``"Rand Seed"``; nodes created with :meth:`add_node` hang below it.
    A path is scored by summing the log of each node's ``softmax`` value
    (see :meth:`_path_score`), and higher scores are better.
    """

    def __init__(self, steps_back, beam_width, fixed_goals, n_steps):
        """Create a tree containing only the placeholder root node.

        Args:
            steps_back: Stored as-is; not read by any method of this class.
            beam_width: Stored as-is; not read by any method of this class.
            fixed_goals: Stored as-is; not read by any method of this class.
            n_steps: Minimum number of nodes a path must have to be
                considered by :meth:`best_path`.
        """
        # Positional order matches add_node's Node(...) call: step_text,
        # image, latents, heat_map, step, generated_from_step,
        # generated_from_latent (parent omitted).
        self.root = Node(None, None, None, None, "Rand Seed", None, None)
        self.nodes = [self.root]
        self.n_steps = n_steps
        self.steps_back = steps_back
        self.fixed_goals = fixed_goals
        self.beam_width = beam_width

    def add_node(self, parent, step_text, step, generated_from_step,
                 generated_from_latent, image, latents, heat_map):
        """Create a node, attach it under ``parent`` and return it.

        The new node is also appended to ``self.nodes``.
        """
        new_node = Node(step_text, image, latents, heat_map, step,
                        generated_from_step, generated_from_latent, parent)
        parent.add_child(new_node)
        self.nodes.append(new_node)
        return new_node

    def all_paths(self):
        """Return one entry per child of the root.

        Each entry is whatever ``child.find_all_paths()`` returns for that
        subtree; the rest of this class iterates it as a collection of
        paths, each path being a sequence of nodes.
        """
        return [child.find_all_paths() for child in self.root.children]

    def get_previous_steps_features(self, node):
        """Collect features of the ancestors of ``node``, oldest first.

        Walks from ``node.parent`` toward the root, excluding the root
        itself, calling ``get_features()`` on each ancestor.

        Returns:
            A ``(previous_steps, previous_images)`` tuple of lists holding the
            text and image features respectively, ordered from the ancestor
            nearest the root down to ``node``'s parent.
        """
        previous_steps = []
        previous_images = []
        current = node.parent
        # Stop before the placeholder root (the only node without a parent).
        while current is not None and current.parent is not None:
            text_embedding, img_embedding = current.get_features()
            previous_steps.append(text_embedding)
            previous_images.append(img_embedding)
            current = current.parent
        # Features were gathered nearest-first; callers get oldest-first.
        previous_steps.reverse()
        previous_images.reverse()
        return previous_steps, previous_images

    @staticmethod
    def _path_score(path):
        """Return the summed log-softmax score of ``path`` (higher is better)."""
        score = 0.0
        for node in path:
            # Non-positive softmax values are added as-is instead of being
            # passed to log (log(0) would be -inf), so a node whose softmax
            # is 0 contributes nothing. NOTE(review): confirm this is the
            # intended treatment of unscored / zero-probability nodes.
            score += np.log(node.softmax) if node.softmax > 0 else node.softmax
        return score

    def _iter_paths(self, min_length):
        """Yield every path from :meth:`all_paths` with >= ``min_length`` nodes."""
        for subtree_paths in self.all_paths():
            for path in subtree_paths:
                if len(path) >= min_length:
                    yield path

    def best_path(self):
        """Return ``(path, score)`` for the highest-scoring long-enough path.

        Only paths with at least ``self.n_steps`` nodes are considered. If no
        path qualifies (or every score is ``-inf``), returns
        ``(None, float('-inf'))``. Ties keep the first path encountered.
        """
        best_path = None
        best_score = float('-inf')
        for path in self._iter_paths(self.n_steps):
            score = self._path_score(path)
            if score > best_score:
                best_score, best_path = score, path
        return best_path, best_score

    def best_path_imgs(self):
        """Return the ``image`` of each node on :meth:`best_path`, in order.

        Returns an empty list when no path is long enough to qualify.
        """
        best_path, _ = self.best_path()
        if best_path is None:
            return []
        return [node.image for node in best_path]

    def get_n_best_paths(self, n, path_length):
        """Return up to ``n`` highest-scoring paths, best first.

        Args:
            n: Maximum number of paths to return; ``n <= 0`` yields ``[]``.
            path_length: Minimum number of nodes a path must have to be
                considered.

        Returns:
            A list of paths sorted by descending score (as computed by
            :meth:`_path_score`); among equal scores, earlier paths come first.
        """
        # Score from 0.0: starting at -inf made every score -inf, so no
        # ranking happened and the first n paths were returned regardless.
        candidates = (
            (self._path_score(path), path)
            for path in self._iter_paths(path_length)
        )
        # key= avoids comparing the paths themselves on score ties.
        best = heapq.nlargest(n, candidates, key=lambda item: item[0])
        return [path for _, path in best]