# BeamDiffusion / tree/tree.py
# Uploaded by Gui28F: "uploaded all project files" (commit 173ea2b, verified).
# Original file size: 3.14 kB.
from BeamDiffusionModel.tree.node import Node
import numpy as np
class BeamSearchTree:
    """Tree of generated steps explored by a beam search.

    The tree is rooted at a sentinel ``Node`` whose step is ``"Rand Seed"``.
    Every node added through :meth:`add_node` is also recorded in
    ``self.nodes``. Paths are scored by summing the log of each node's
    ``softmax`` value (see :meth:`_path_score`).

    Attributes:
        root: sentinel root node; its children start the individual subtrees.
        nodes: every node in the tree, root first, in insertion order.
        n_steps: minimum path length considered by :meth:`best_path`.
        steps_back, fixed_goals, beam_width: stored configuration; not read
            by this class itself (presumably consumed by the search driver —
            verify against callers).
    """

    def __init__(self, steps_back, beam_width, fixed_goals, n_steps):
        self.root = Node(None, None, None, None, "Rand Seed", None, None)
        self.nodes = [self.root]
        self.n_steps = n_steps
        self.steps_back = steps_back
        self.fixed_goals = fixed_goals
        self.beam_width = beam_width

    def add_node(self, parent, step_text, step, generated_from_step,
                 generated_from_latent, image, latents, heat_map):
        """Create a node, attach it under ``parent`` and return it."""
        new_node = Node(step_text, image, latents, heat_map, step,
                        generated_from_step, generated_from_latent, parent)
        parent.add_child(new_node)
        self.nodes.append(new_node)
        return new_node

    def all_paths(self):
        """Return one entry per child of the root.

        Each entry is whatever ``child.find_all_paths()`` returns; the rest of
        this class treats it as an iterable of paths, each path being a sized
        sequence of nodes (the sentinel root is not part of any path).
        """
        return [subtree_root.find_all_paths() for subtree_root in self.root.children]

    def _iter_paths(self):
        """Yield every path from every subtree as a flat stream."""
        for subtree_paths in self.all_paths():
            yield from subtree_paths

    @staticmethod
    def _path_score(path):
        """Return the summed log-softmax score of ``path``.

        Nodes with ``softmax > 0`` contribute ``np.log(softmax)``. Nodes with
        ``softmax <= 0`` contribute their raw value.
        NOTE(review): a softmax of exactly 0 therefore adds 0 (as if the node
        had probability 1) rather than -inf; preserved from the original
        implementation — confirm this is intended.
        """
        score = 0.0
        for node in path:
            prob = node.softmax
            score += np.log(prob) if prob > 0 else prob
        return score

    def get_previous_steps_features(self, node):
        """Collect (text, image) features of ``node``'s ancestors.

        Walks from ``node.parent`` upward, stopping before the parentless
        sentinel root; ``node`` itself is excluded.

        Returns:
            ``(previous_steps, previous_images)``: two lists ordered from the
            oldest ancestor to the immediate parent, built from each
            ancestor's ``get_features()`` pair.
        """
        previous_steps = []
        previous_images = []
        current = node.parent
        while current is not None and current.parent is not None:
            text_embedding, img_embedding = current.get_features()
            previous_steps.append(text_embedding)
            previous_images.append(img_embedding)
            current = current.parent
        # Collected nearest-first while walking up; callers expect oldest-first.
        previous_steps.reverse()
        previous_images.reverse()
        return previous_steps, previous_images

    def best_path(self):
        """Return ``(path, score)`` for the highest-scoring path.

        Only paths with at least ``self.n_steps`` nodes are considered. Ties
        keep the first path encountered. Returns ``(None, float('-inf'))``
        when no path is long enough.
        """
        best_path = None
        best_score = float('-inf')
        for path in self._iter_paths():
            if len(path) < self.n_steps:
                continue
            score = self._path_score(path)
            if score > best_score:
                best_score = score
                best_path = path
        return best_path, best_score

    def best_path_imgs(self):
        """Return the ``image`` of each node on the best path.

        Returns an empty list when no path reaches ``self.n_steps`` nodes.
        """
        best_path, _ = self.best_path()
        if best_path is None:
            return []
        return [node.image for node in best_path]

    def get_n_best_paths(self, n, path_length):
        """Return up to ``n`` highest-scoring paths, best first.

        Args:
            n: maximum number of paths to return; ``n <= 0`` yields ``[]``.
            path_length: minimum number of nodes a path must have.

        Returns:
            A list of paths sorted by descending score. Equal scores keep
            discovery order.
        """
        if n <= 0:
            return []
        # Score each path from 0.0 (the original started at -inf, which made
        # every score -inf and the ranking meaningless).
        scored = [
            (self._path_score(path), path)
            for path in self._iter_paths()
            if len(path) >= path_length
        ]
        # Sort on the score only (never compare paths); sort is stable.
        scored.sort(key=lambda item: item[0], reverse=True)
        return [path for _, path in scored[:n]]