# Your Name
# asd
# 37e5bdb
# %%
from pathlib import Path
import torch
from sentence_transformers import SentenceTransformer
# %%
def create_splits(p):
    """Split a prompt into words and build its leave-one-out variants.

    Args:
        p: Prompt text; it is tokenised on whitespace with ``str.split()``.

    Returns:
        A ``(words, omit_prompts)`` tuple. ``words`` is the list of
        whitespace-separated tokens, and ``omit_prompts[j]`` is the prompt
        re-joined with single spaces after dropping ``words[j]``.
    """
    words = p.split()
    omit_prompts = []
    for skip in range(len(words)):
        # Everything before and after position ``skip``: the prompt minus one word.
        remaining = words[:skip] + words[skip + 1 :]
        omit_prompts.append(" ".join(remaining))
    return words, omit_prompts
# %%
from abc import ABC, abstractmethod
class IE(ABC):
    """Abstract interface for word-importance evaluators.

    Subclasses must implement ``get_word_importance_chunked``; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Score the words of ``PROMPT``; concrete subclasses override this.

        The base implementation does nothing and returns ``None``.
        """
class ImportanceEvaluatorStatic(IE):
    """Leave-one-out word importance using a sentence-embedding model.

    Each word is scored by how much the similarity between the full prompt
    and the prompt-without-that-word drops; scores are scaled so that the
    most important word gets 1.0.
    """

    def __init__(self):
        # Download from the Hub
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    @staticmethod
    def _scores_from_similarities(similarities):
        """Turn similarities-to-the-full-prompt into normalised importances.

        Args:
            similarities: 1-D float tensor whose element 0 is the similarity
                of the prompt with itself and whose element ``j + 1`` is the
                similarity of the prompt with word ``j`` omitted.

        Returns:
            A tensor with one score per word (element 0 is dropped). Scores
            are non-negative and, unless every score is zero, the largest
            one is 1.0.
        """
        # log(0) is -inf and log(negative) is NaN, and either one poisons the
        # max-normalisation below. Clamp to the smallest positive normal value
        # so such entries become "very large but finite" instead.
        # NOTE(review): presumably a zero similarity shows up when omitting
        # the only word leaves an empty string — confirm with the model.
        floor = torch.finfo(similarities.dtype).tiny
        x = -similarities.clamp(min=floor).log()  # importance of a word is the inverse of similarity-when-dropped
        x = x - x[0]  # subtract self-similarity as the baseline
        x = x.clamp(0)
        peak = x.max()
        # If no omission lowered the similarity, every score is 0; dividing
        # by that maximum would give NaN, so leave the zeros untouched.
        if peak > 0:
            x = x / peak
        return x[1:]

    def get_word_importance(self, PROMPT):
        """Score every whitespace-separated word of ``PROMPT``.

        Args:
            PROMPT: The text to analyse.

        Returns:
            ``(importances, words)`` where ``importances`` is a tensor with
            one entry per word and ``words`` is the list of words.
        """
        words, omit_prompts = create_splits(PROMPT)
        # Row 0 is the full prompt; rows 1.. are the leave-one-out variants.
        sentences = [PROMPT] + omit_prompts
        embeddings = self.model.encode(sentences)
        similarities = self.model.similarity(embeddings[0:1], embeddings)
        return self._scores_from_similarities(similarities[0]), words

    def get_word_importance_chunked(self, PROMPT):
        """Implement the ``IE`` interface by delegating to ``get_word_importance``."""
        return self.get_word_importance(PROMPT)
# %%
def compute_static_word_importances(
    f: Path, ie: "ImportanceEvaluatorStatic", overwrite=False
):
    """Compute and cache per-line word importances for caption files under ``f``.

    For every directory matching ``f/.captions/*`` and every ``*.txt`` file
    directly inside it, each line of the caption is scored with
    ``ie.get_word_importance_chunked`` (empty lines are stored as ``None``).
    The resulting list is stored under the key ``ie.CLIP_MODEL_ID`` in a dict
    saved with ``torch.save`` to ``<caption dir>/.meta/<file stem>.pth``.
    Other keys already present in that dict are kept.

    Args:
        f: Directory that contains the ``.captions`` folder.
        ie: Evaluator providing ``CLIP_MODEL_ID`` and
            ``get_word_importance_chunked``.
        overwrite: When false, files that already hold an entry for this
            model id are skipped; when true, the entry is recomputed.

    Errors while processing a single file are printed and do not stop the
    run.
    """
    model_id = ie.CLIP_MODEL_ID
    for c in f.glob(".captions/*"):
        # Stray regular files next to the caption folders cannot be iterated.
        if not c.is_dir():
            continue
        metadir = c / ".meta"
        for file in c.iterdir():
            if file.suffix != ".txt" or not file.is_file():
                continue
            # Bound before the ``try`` so the error message below always
            # names the file that actually failed.
            out = metadir / file.with_suffix(".pth").name
            try:
                r = {}
                if out.exists():
                    # NOTE(review): weights_only=False unpickles arbitrary
                    # objects; only safe for cache files this tool wrote.
                    r = torch.load(out, weights_only=False)
                    # Explicit raise instead of ``assert`` so the check
                    # survives ``python -O``.
                    if not isinstance(r, dict):
                        raise TypeError("corrupt format")
                    if (not overwrite) and (model_id in r):
                        continue
                caption = file.read_text()
                r[model_id] = [
                    ie.get_word_importance_chunked(line) if line else None
                    for line in caption.split("\n")
                ]
                metadir.mkdir(exist_ok=True)
                torch.save(r, out)
            except Exception as e:
                # Best effort: report the failure and move on to the next file.
                print("ERROR", out, e)
def yield_dirs(root: Path):
    """Yield every immediate subdirectory of ``root``.

    Non-directory entries are skipped. Directories whose name starts with
    ``"."`` are still yielded.

    Args:
        root: Directory whose direct children are inspected.

    Yields:
        ``Path`` objects for the child directories, in ``iterdir()`` order.
    """
    for subset in root.iterdir():
        # Callers treat every result as a directory, so plain files are
        # filtered out here.
        if subset.is_dir():
            yield subset
if __name__ == "__main__":
    # Script entry point: score the captions of every entry under ``root``.
    from tqdm import tqdm

    ies = ImportanceEvaluatorStatic()
    root = Path("/path_to_my_files")
    # ``yield_dirs`` takes only the root directory. Wrapping the generator in
    # tqdm advances the bar once per entry and closes it when the loop ends.
    for f in tqdm(yield_dirs(root)):
        print(f)
        compute_static_word_importances(f, ies, overwrite=False)