# %%
from abc import ABC, abstractmethod
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer

# Floor applied to similarities before taking the logarithm. A similarity
# that is zero or negative would otherwise produce inf/NaN, which survives
# clamping and turns every importance of the prompt into NaN.
_MIN_SIMILARITY = 1e-6


# %%
def create_splits(p):
    """Split a prompt into words and build its leave-one-out variants.

    Args:
        p: Prompt text; it is split on whitespace.

    Returns:
        A ``(words, omit_prompts)`` tuple where ``omit_prompts[j]`` is the
        prompt re-joined with single spaces and with ``words[j]`` left out.
        Both lists are empty when the prompt contains no words.
    """
    words = p.split()
    omit_prompts = [
        " ".join(words[:j] + words[j + 1:]) for j in range(len(words))
    ]
    return words, omit_prompts


# %%
class IE(ABC):
    """Interface for objects that score the importance of a prompt's words."""

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Return the word importances for ``PROMPT``."""


class ImportanceEvaluatorStatic(IE):
    """Leave-one-out word importance based on a static embedding model."""

    def __init__(self):
        # The model is fetched from the Hub on first use. The attribute keeps
        # its ``CLIP_MODEL_ID`` name because it is also the key under which
        # results are stored by ``compute_static_word_importances``.
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT):
        """Score each word by how much dropping it changes the embedding.

        The full prompt and every leave-one-out variant are embedded, and
        each variant is compared with the full prompt using the model's own
        similarity function. A word's importance is the negative log of the
        similarity obtained without it, measured relative to the prompt's
        self-similarity, clamped at zero and scaled so the largest value
        is 1.

        Args:
            PROMPT: Prompt text to analyse.

        Returns:
            A ``(importances, words)`` tuple. ``importances`` is a 1-D tensor
            with one entry per word. It is empty when the prompt has no
            words and all zeros when no word changes the similarity.
        """
        words, omit_prompts = create_splits(PROMPT)
        sentences = [PROMPT] + omit_prompts
        embeddings = self.model.encode(sentences)
        similarities = self.model.similarity(embeddings[0:1], embeddings)

        # The logarithm needs strictly positive input. Dropping a word can
        # push the similarity to zero or below (e.g. a one-word prompt leaves
        # an empty string), so floor it instead of producing inf/NaN.
        x = similarities[0].clamp(min=_MIN_SIMILARITY)
        x = -x.log()  # lower similarity-when-dropped means higher importance
        x = x - x[0]  # subtract self-similarity as the baseline
        x = x.clamp(min=0)

        # Normalise only when there is something to normalise; dividing by a
        # zero maximum would turn the all-zero result into NaN.
        peak = x.max()
        if peak > 0:
            x = x / peak
        return x[1:], words

    def get_word_importance_chunked(self, PROMPT):
        """Return ``get_word_importance(PROMPT)``; no chunking is applied."""
        return self.get_word_importance(PROMPT)


# %%
def compute_static_word_importances(
    f: Path, ie: ImportanceEvaluatorStatic, overwrite=False
):
    """Compute and cache per-line word importances for caption files.

    Every ``*.txt`` file found in a directory matching ``f/.captions/*`` is
    read and each non-empty line is scored with
    ``ie.get_word_importance_chunked``. The scores are stored in
    ``<caption dir>/.meta/<stem>.pth`` as a dict keyed by
    ``ie.CLIP_MODEL_ID``; other keys already present in that file are kept.

    Args:
        f: Directory that contains the ``.captions`` folder.
        ie: Evaluator used to score the lines.
        overwrite: When false, files that already hold an entry for this
            model are skipped.

    Failures on a single file are printed and do not stop the run.
    """
    model_id = ie.CLIP_MODEL_ID
    for c in f.glob(".captions/*"):
        if not c.is_dir():
            # A stray file next to the caption folders cannot be iterated.
            continue
        metadir = c / ".meta"
        for file in c.iterdir():
            if file.suffix != ".txt" or not file.is_file():
                continue
            # Bound before the try so the error message can always name it.
            out = metadir / file.with_suffix(".pth").name
            try:
                r = {}
                if out.exists():
                    # NOTE: weights_only=False unpickles arbitrary objects;
                    # only load metadata files that this tool wrote itself.
                    r = torch.load(out, weights_only=False)
                    if not isinstance(r, dict):
                        raise TypeError("corrupt format")
                    if not overwrite and model_id in r:
                        continue
                caption = file.read_text()
                # Empty lines keep their position but get no scores.
                r[model_id] = [
                    ie.get_word_importance_chunked(line) if line else None
                    for line in caption.split("\n")
                ]
                metadir.mkdir(exist_ok=True)
                torch.save(r, out)
            except Exception as e:
                # Best-effort batch job: report the file and carry on.
                print("ERROR", out, e)


def yield_dirs(root: Path):
    """Yield the non-hidden sub-directories directly under ``root``.

    Entries whose name starts with ``"."`` and entries that are not
    directories are skipped.
    """
    for subset in root.iterdir():
        if subset.name.startswith(".") or not subset.is_dir():
            continue
        yield subset


if __name__ == "__main__":
    from tqdm import tqdm

    ies = ImportanceEvaluatorStatic()
    root = Path("/path_to_my_files")
    # yield_dirs takes only the root directory; the extra positional argument
    # that used to be passed here raised a TypeError before any work started.
    for f in tqdm(yield_dirs(root)):
        print(f)
        compute_static_word_importances(f, ies, overwrite=False)