File size: 2,093 Bytes
173ea2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from BeamDiffusionModel.models.clip.clip import clip
import numpy as np

class Node:
    """A node in a tree of generated steps.

    Each node stores the text of a step, the path of the image produced for
    it, the latents and heat map associated with it, and links to its parent
    and children. It also carries a log-scaled score set via ``set_softmax``
    (0.0 until set).

    NOTE(review): the exact meaning of ``latents``, ``heat_map``,
    ``generated_from_step`` and ``generated_from_latent`` is defined by the
    callers that build the tree; this class only stores them.
    """

    def __init__(self, step_text, image_path, latents, heat_map, step, generated_from_step, generated_from_latent, parent=None):
        self.step_text = step_text
        self.step = step
        self.generated_from_step = generated_from_step
        self.generated_from_latent = generated_from_latent
        # The image path is exposed as ``self.image`` (used by get_features).
        self.image = image_path
        self.latents = latents
        self.heat_map = heat_map
        self.parent = parent
        self.softmax = 0.
        self.children = []

    def get_ancestors(self, n):
        """Return ``self`` followed by up to ``n`` of its nearest ancestors.

        Ancestors are listed from closest (parent) to farthest. A node that
        has no parent (the tree root) is never included unless it is
        ``self``; a non-positive ``n`` returns just ``[self]``.
        """
        ancestors = [self]
        current_node = self.parent
        # Stop before the parentless root node, or once n ancestors are collected.
        while current_node is not None and current_node.parent is not None and n > 0:
            ancestors.append(current_node)
            current_node = current_node.parent
            n -= 1
        return ancestors

    def set_softmax(self, softmax, n_latents, n_max_latents, epsilon=1e-6):
        """Store ``log(softmax * n_latents / n_max_latents + epsilon)``.

        ``softmax`` may be a torch tensor (on any device, with or without
        autograd history), a NumPy array, or a plain number. ``epsilon``
        keeps the logarithm finite when the scaled value is 0.
        """
        if hasattr(softmax, "detach"):
            # Torch tensors: drop autograd history (NumPy refuses tensors that
            # require grad) and move to host memory before calling np.log.
            softmax = softmax.detach().cpu()
        self.softmax = np.log((softmax * (1 / n_max_latents)) / (1 / n_latents) + epsilon)

    def get_softmax(self):
        """Return the value last stored by ``set_softmax`` (0.0 if never set)."""
        return self.softmax

    def add_child(self, node):
        """Append ``node`` to this node's children (does not set ``node.parent``)."""
        self.children.append(node)

    def find_all_paths(self):
        """Return every root-to-leaf path of the subtree rooted at ``self``.

        Each path is a list of nodes starting with ``self`` and ending at a
        leaf (a node with no children). The traversal is an iterative DFS
        over an explicit stack, so sibling subtrees are emitted in reverse
        order of ``children``.
        """
        all_paths = []
        stack = [(self, [self])]  # (current_node, path from self to current_node)

        while stack:
            current_node, path = stack.pop()

            if not current_node.children:
                # Leaf: the path is complete.
                all_paths.append(path)
            else:
                for child in current_node.children:
                    stack.append((child, path + [child]))

        return all_paths

    def get_latent(self, idx):
        """Return ``self.latents[idx]``."""
        return self.latents[idx]

    def get_features(self):
        """Return ``(text_embedding, img_embedding)`` for this node.

        Both are produced by ``clip.generate_embedding``: one from the step
        text and one from ``self.image`` (the stored image path).
        NOTE(review): the embedding type/shape depends on the project's clip
        module — confirm there.
        """
        text_embedding = clip.generate_embedding(self.step_text)
        img_embedding = clip.generate_embedding(self.image)
        return text_embedding, img_embedding