Spaces:
Sleeping
Sleeping
| # %% | |
| from pathlib import Path | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| # %% | |
def create_splits(p):
    """Split a prompt into words plus one variant per word with that word left out.

    Args:
        p: the prompt text; it is tokenised with ``str.split()`` (any whitespace).

    Returns:
        ``(words, omit_prompts)``:
        - ``words`` is ``p.split()``.
        - ``omit_prompts[j]`` is every word except ``words[j]``, joined by single spaces.
    """
    words = p.split()
    # Slice around each position rather than filtering by index.
    omit_prompts = [
        " ".join(words[:idx] + words[idx + 1:]) for idx, _ in enumerate(words)
    ]
    return words, omit_prompts
| # %% | |
| from abc import ABC, abstractmethod | |
class IE(ABC):
    """Interface for word-importance evaluators.

    Concrete evaluators must implement ``get_word_importance_chunked``.
    The method is declared abstract, so instantiating ``IE`` itself (or a
    subclass that forgets to implement it) raises ``TypeError`` instead of
    silently returning ``None``.
    """

    @abstractmethod
    def get_word_importance_chunked(self, PROMPT):
        """Return per-word importance information for ``PROMPT``.

        In this file, ``ImportanceEvaluatorStatic`` returns a ``(scores, words)`` pair.
        """
class ImportanceEvaluatorStatic(IE):
    """Word-importance evaluator backed by a static sentence-embedding model.

    A word's importance is derived from how much the embedding of the prompt
    changes when that single word is omitted.
    """

    def __init__(self):
        # Download from the Hub. The attribute name is kept as-is because
        # callers read it (e.g. as the key under which results are stored).
        self.CLIP_MODEL_ID = "sentence-transformers/static-retrieval-mrl-en-v1"
        self.model = SentenceTransformer(self.CLIP_MODEL_ID)

    def get_word_importance(self, PROMPT, eps=1e-8):
        """Score each whitespace-separated word of ``PROMPT``.

        Args:
            PROMPT: the text to analyse.
            eps: lower bound applied to similarities before taking the log.
                Cosine similarity can be zero or negative, where ``log`` gives
                ``inf``/``NaN``. Such words instead get a large finite score.

        Returns:
            ``(scores, words)``:
            - ``words`` is the prompt's whitespace split.
            - ``scores`` is a 1-D tensor with one entry per word, scaled into
              ``[0, 1]``.
            - If no omission lowers the similarity (e.g. ``"a a"``, or an empty
              prompt), all scores are 0 instead of ``NaN``.
        """
        words, omit_prompts = create_splits(PROMPT)
        sentences = [PROMPT] + omit_prompts
        embeddings = self.model.encode(sentences)
        # Compare the full prompt (row 0) against itself and every omission.
        similarities = self.model.similarity(embeddings[0:1], embeddings)
        x = similarities[0]
        # Importance of a word is the inverse of similarity-when-dropped.
        # Clamp first so log() never sees a non-positive value.
        x = -x.clamp(min=eps).log()
        x = x - x[0]  # subtract self-similarity as the baseline
        x = x.clamp(0)
        peak = x.max()
        # Normalise only when something is non-zero; otherwise 0/0 -> NaN.
        if peak > 0:
            x = x / peak
        return x[1:], words

    def get_word_importance_chunked(self, PROMPT):
        """Same as ``get_word_importance``; this implementation does no chunking."""
        return self.get_word_importance(PROMPT)
| # %% | |
def compute_static_word_importances(
    f: Path, ie: ImportanceEvaluatorStatic, overwrite=False
):
    """Compute and cache per-line word importances for caption files under ``f``.

    Which files are processed:
    - every directory ``c`` matching ``f/.captions/*``;
    - every ``*.txt`` file inside ``c``.

    For each such file, the list of per-line results is stored under the key
    ``ie.CLIP_MODEL_ID`` in ``c/.meta/<stem>.pth``. The file holds a dict, so
    existing entries for other keys are loaded and kept.

    Args:
        f: directory containing a ``.captions`` folder.
        ie: evaluator exposing ``CLIP_MODEL_ID`` and
            ``get_word_importance_chunked``.
        overwrite: recompute even when an entry for this model already exists.

    Errors for an individual file are printed and that file is skipped, so one
    bad file does not abort the rest of the directory.
    """
    model_id = ie.CLIP_MODEL_ID
    for c in f.glob(".captions/*"):
        # The glob also matches plain files; only directories hold captions.
        if not c.is_dir():
            continue
        metadir = c / ".meta"
        for file in c.iterdir():
            if file.suffix != ".txt" or not file.is_file():
                continue
            # Computed before the try so the error message can always use it.
            out = metadir / file.with_suffix(".pth").name
            try:
                r = {}
                if out.exists():
                    # NOTE(review): weights_only=False unpickles arbitrary objects;
                    # acceptable only because these files are written by this script.
                    r = torch.load(out, weights_only=False)
                    if not isinstance(r, dict):
                        raise TypeError("corrupt format")
                if model_id in r and not overwrite:
                    continue
                caption = file.read_text(encoding="utf-8")
                # One entry per line (None for empty lines) so the list indices
                # line up with the caption's lines.
                r[model_id] = [
                    ie.get_word_importance_chunked(line) if line else None
                    for line in caption.split("\n")
                ]
                metadir.mkdir(exist_ok=True)
                torch.save(r, out)
            except Exception as e:
                print("ERROR", out, e)
def yield_dirs(root: Path, skip_hidden: bool = True):
    """Yield the immediate subdirectories of ``root`` in sorted order.

    Plain files are always skipped; only directories are useful to callers
    that look for a ``.captions`` folder inside each result.

    Args:
        root: directory to scan (non-recursively).
        skip_hidden: when true (the default), also skip entries whose name
            starts with ``"."``.
    """
    # Sorting gives a deterministic processing order across runs.
    for subset in sorted(root.iterdir()):
        if not subset.is_dir():
            continue
        if skip_hidden and subset.name.startswith("."):
            continue
        yield subset
if __name__ == "__main__":
    from tqdm import tqdm

    ies = ImportanceEvaluatorStatic()
    root = Path("/path_to_my_files")
    # Let tqdm drive the bar from the iterable.
    # tqdm.write prints above the bar instead of corrupting it.
    for f in tqdm(yield_dirs(root), unit="dir"):
        tqdm.write(str(f))
        compute_static_word_importances(f, ies, overwrite=False)