File size: 3,068 Bytes
37e5bdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# %%
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer


# %%


def create_splits(p):
    """Split a prompt into words and build its leave-one-out variants.

    Args:
        p: Prompt text; it is tokenised on whitespace with ``str.split()``.

    Returns:
        A ``(words, omit_prompts)`` tuple. ``words`` is the list of
        whitespace-separated tokens and ``omit_prompts[j]`` is the prompt
        re-joined with single spaces after removing ``words[j]``.
    """
    tokens = p.split()
    leave_one_out = []
    for skip in range(len(tokens)):
        # Everything before and after position `skip`, i.e. the prompt
        # without that one word.
        remaining = tokens[:skip] + tokens[skip + 1:]
        leave_one_out.append(" ".join(remaining))
    return tokens, leave_one_out


# %%
from abc import ABC, abstractmethod


class IE(ABC):
    """Abstract base class for importance evaluators.

    Concrete subclasses must implement ``get_word_importance_chunked``;
    the base class itself cannot be instantiated.
    """

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Abstract hook taking a prompt; the base implementation returns None."""
        return None


class ImportanceEvaluatorStatic(IE):
    """Word-importance evaluator built on a static sentence-embedding model.

    The importance of a word is derived from how far the prompt's embedding
    moves away from the full prompt's embedding when that word is removed.
    """

    def __init__(self):
        # Download from the Hub
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score every whitespace-separated word of ``PROMPT``.

        Args:
            PROMPT: The text to analyse.

        Returns:
            A ``(importances, words)`` tuple. ``importances`` holds one
            non-negative score per entry of ``words``; when at least one
            score is positive they are scaled so the largest equals 1,
            otherwise they are all 0.
        """
        words, omit_prompts = create_splits(PROMPT)

        # Row 0 is the full prompt, row j + 1 is the prompt without word j.
        sentences = [PROMPT] + omit_prompts

        embeddings = self.model.encode(sentences)

        similarities = self.model.similarity(embeddings[0:1], embeddings)

        x = similarities[0]
        # A similarity of zero or below makes the log infinite or NaN, which
        # then turns the normalisation below into NaN. Flooring it at the
        # smallest positive float keeps the score finite and maximal instead.
        x = x.clamp(min=torch.finfo(x.dtype).tiny)
        x = -x.log()  # importance of a word is the inverse of similarity-when-dropped
        x = x - x[0]  # subtract self-similarity as the baseline

        x = x.clamp(0)
        peak = x.max()
        if peak > 0:
            # Only normalise when there is something to scale; dividing by a
            # zero maximum would fill the result with NaN.
            x = x / peak
        return x[1:], words

    def get_word_importance_chunked(self, PROMPT):
        """Implement the ``IE`` interface by delegating to ``get_word_importance``."""
        return self.get_word_importance(PROMPT)


# %%


def compute_static_word_importances(
    f: Path, ie: "ImportanceEvaluatorStatic", overwrite=False
):
    """Compute and cache per-line word importances for the captions under ``f``.

    For every directory ``<f>/.captions/<name>/`` each ``*.txt`` file is read
    and split on ``"\\n"``. Each non-empty line is passed to
    ``ie.get_word_importance_chunked``; empty lines are recorded as ``None``.
    The resulting list is stored under the key ``ie.CLIP_MODEL_ID`` in the
    dict saved at ``<f>/.captions/<name>/.meta/<stem>.pth``. Other keys
    already present in that file are preserved.

    Args:
        f: Directory that contains a ``.captions`` sub-directory.
        ie: Evaluator providing ``CLIP_MODEL_ID`` and
            ``get_word_importance_chunked``.
        overwrite: When false, files whose cache already holds an entry for
            the evaluator's model id are skipped; when true they are
            recomputed.

    Failures on a single caption file are printed and do not stop the run.
    """
    model_id = ie.CLIP_MODEL_ID
    for c in f.glob(".captions/*"):
        if not c.is_dir():
            # A stray file next to the caption directories cannot be
            # iterated and must not abort the whole run.
            continue
        metadir = c / ".meta"
        for file in c.iterdir():
            if file.suffix != ".txt" or not file.is_file():
                continue
            # print(file)
            # Bound before the try so the error message below always refers
            # to the file currently being processed.
            out = metadir / file.with_suffix(".pth").name
            try:
                r = {}
                if out.exists():
                    # NOTE(security): weights_only=False unpickles arbitrary
                    # objects; only load cache files this tool wrote itself.
                    r = torch.load(out, weights_only=False)
                    if not isinstance(r, dict):
                        # Explicit raise: an assert would vanish under -O.
                        raise TypeError("corrupt format")
                    if (not overwrite) and (model_id in r):
                        continue

                caption = file.read_text()
                r[model_id] = [
                    ie.get_word_importance_chunked(line) if line else None
                    for line in caption.split("\n")
                ]

                metadir.mkdir(exist_ok=True)
                torch.save(r, out)
            except Exception as e:
                # Best effort per file: report and carry on with the rest.
                print("ERROR", out, e)


def yield_dirs(root: Path):
    """Yield the non-hidden immediate sub-directories of ``root``.

    Args:
        root: Directory whose direct children are inspected.

    Yields:
        Each child of ``root`` that is a directory and whose name does not
        start with ``"."``. Plain files are never yielded.
    """
    for subset in root.iterdir():
        # The original test was inverted and yielded files instead of
        # directories.
        if not subset.is_dir() or subset.name.startswith("."):
            continue
        yield subset


if __name__ == "__main__":
    # Script entry point: compute word importances for every entry that
    # yield_dirs reports under the root folder.
    from tqdm import tqdm

    ies = ImportanceEvaluatorStatic()
    root = Path("/path_to_my_files")
    # yield_dirs takes only the root; the extra positional argument passed
    # previously raised TypeError before any work was done. Wrapping the
    # generator lets tqdm advance and close the progress bar itself.
    for f in tqdm(yield_dirs(root)):
        print(f)
        compute_static_word_importances(f, ies, overwrite=False)