|
|
from BeamDiffusionModel.tree.node import Node
|
|
|
import numpy as np
|
|
|
|
|
|
class BeamSearchTree:
    """Tree of candidate steps explored by a beam search.

    The root is a placeholder node whose ``step`` is ``"Rand Seed"``; each
    child of the root starts its own subtree of generated steps. Paths are
    scored by summing the log of each node's ``softmax`` value.
    """

    def __init__(self, steps_back, beam_width, fixed_goals, n_steps):
        """Create a tree that contains only the placeholder root.

        Args:
            steps_back: Kept for external callers; no method here reads it.
            beam_width: Kept for external callers; no method here reads it.
            fixed_goals: Kept for external callers; no method here reads it.
            n_steps: Minimum number of nodes a path needs for ``best_path``
                (and therefore ``best_path_imgs``) to consider it.
        """
        # Positional order matches the call in add_node:
        # Node(step_text, image, latents, heat_map, step,
        #      generated_from_step, generated_from_latent[, parent]).
        self.root = Node(None, None, None, None, "Rand Seed", None, None)
        # Every node ever created, root included, in creation order.
        self.nodes = [self.root]
        self.n_steps = n_steps
        self.steps_back = steps_back
        self.fixed_goals = fixed_goals
        self.beam_width = beam_width

    def add_node(self, parent, step_text, step, generated_from_step,
                 generated_from_latent, image, latents, heat_map):
        """Create a child of ``parent``, register it, and return it.

        The new node is attached with ``parent.add_child`` and appended to
        ``self.nodes``.
        """
        new_node = Node(step_text, image, latents, heat_map, step,
                        generated_from_step, generated_from_latent, parent)
        parent.add_child(new_node)
        self.nodes.append(new_node)
        return new_node

    def all_paths(self):
        """Return the paths of every subtree hanging off the root.

        The result holds one entry per child of the root, in child order.
        Each entry is whatever that child's ``find_all_paths()`` returns; the
        scoring methods of this class treat it as an iterable of node
        sequences (presumably root-to-leaf paths - confirm against ``Node``).
        """
        return [child.find_all_paths() for child in self.root.children]

    def get_previous_steps_features(self, node):
        """Collect features of the ancestors of ``node``, oldest first.

        Walks from ``node.parent`` upward and stops before the tree root
        (the ancestor whose ``parent`` is ``None``); ``node`` itself is
        excluded.

        Returns:
            ``(text_embeddings, image_embeddings)``: two lists ordered from the
            ancestor nearest the root down to ``node``'s parent, built from
            each ancestor's ``get_features()``.
        """
        text_embeddings = []
        image_embeddings = []
        current = node.parent
        while current is not None and current.parent is not None:
            text_embedding, img_embedding = current.get_features()
            text_embeddings.append(text_embedding)
            image_embeddings.append(img_embedding)
            current = current.parent
        # Collected leaf-to-root; reverse once instead of prepending each time.
        text_embeddings.reverse()
        image_embeddings.reverse()
        return text_embeddings, image_embeddings

    @staticmethod
    def _path_score(path):
        """Return the summed log-softmax score of the nodes in ``path``."""
        score = 0.0
        for node in path:
            # NOTE(review): a non-positive softmax adds its raw value instead
            # of log(0) = -inf, so a zero-probability node is not penalised.
            # Preserved as-is; confirm this is intended.
            score += np.log(node.softmax) if node.softmax > 0 else node.softmax
        return score

    def _scored_paths(self, min_length):
        """Yield ``(path, score)`` for each path with >= ``min_length`` nodes."""
        for child_paths in self.all_paths():
            for path in child_paths:
                if len(path) >= min_length:
                    yield path, self._path_score(path)

    def best_path(self):
        """Return the highest-scoring path with at least ``self.n_steps`` nodes.

        Returns:
            ``(path, score)``. On ties the earliest path wins. When no path is
            long enough, returns ``(None, float('-inf'))``.
        """
        best_path = None
        best_score = float('-inf')
        for path, score in self._scored_paths(self.n_steps):
            if score > best_score:
                best_score = score
                best_path = path
        return best_path, best_score

    def best_path_imgs(self):
        """Return the ``image`` of each node on ``best_path()``, in path order.

        Returns an empty list when no path reaches ``self.n_steps`` nodes.
        """
        best_path, _ = self.best_path()
        if best_path is None:
            return []
        return [node.image for node in best_path]

    def get_n_best_paths(self, n, path_length):
        """Return up to ``n`` highest-scoring paths, best first.

        Only paths with at least ``path_length`` nodes are considered. Paths
        with equal scores keep their ``all_paths()`` order. Returns an empty
        list when ``n <= 0`` or no path qualifies.
        """
        if n <= 0:
            return []
        scored = list(self._scored_paths(path_length))
        # list.sort is stable, so equal scores keep their discovery order.
        scored.sort(key=lambda item: item[1], reverse=True)
        return [path for path, _ in scored[:n]]