Spaces:
Running
Running
File size: 3,028 Bytes
4d939fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
"""Compute similarity matrices from embedding databases."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Iterable, Optional
import matplotlib.pyplot as plt
import torch
def gen_data(dict_data):
    """Unpack the four fields of a loaded embedding database.

    Args:
        dict_data: Subscriptable object providing the keys ``"embeddings"``,
            ``"labels"``, ``"ids"`` and ``"classes"``.

    Returns:
        A 4-tuple ``(embeddings, labels, ids, classes)`` holding the values
        stored under those keys, in that order. Any lookup error raised by
        ``dict_data`` for a missing key propagates unchanged.
    """
    field_names = ("embeddings", "labels", "ids", "classes")
    # Keys are read in the same fixed order as the positions of the result.
    return tuple(dict_data[name] for name in field_names)
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the similarity export script.

    Returns:
        A parser accepting ``--database`` and ``--output-dir`` (both required
        and converted to ``Path``) plus an optional ``--layers`` list of
        integers whose default is ``None``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Generate class similarity matrices for DETree.",
    )
    # (flag, keyword options) pairs; registered in this order so the help
    # output lists the options in the same sequence.
    option_specs = (
        (
            "--database",
            {"type": Path, "required": True, "help": "Path to the embedding database (.pt)."},
        ),
        (
            "--output-dir",
            {"type": Path, "required": True, "help": "Directory to store the similarity outputs."},
        ),
        (
            "--layers",
            {"type": int, "nargs": "*", "default": None, "help": "Specific layers to export. Defaults to all."},
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
def compute_similarity(database: Path, output_dir: Path, layers: Optional[Iterable[int]]) -> None:
    """Export a class-by-class similarity matrix for each requested layer.

    For every layer the per-class mean embedding ("centre") is computed, the
    matrix of inner products between centres (``center @ center.T``) is formed,
    and the result is written both as a text table and as a heat-map image.
    The centres are not normalised here, so the values are cosine similarities
    only if the stored embeddings make the centres unit-norm.

    Args:
        database: Path to the embedding database, read with ``torch.load``.
        output_dir: Directory receiving ``similarity_layer_<layer>.txt`` and
            ``similarity_layer_<layer>.png``; created if it does not exist.
        layers: Keys of the embeddings mapping to export. ``None`` exports
            every key present in the database.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(security): torch.load deserialises with pickle; only pass database
    # files from a trusted source.
    data_emb, data_labels, _data_ids, data_classes = gen_data(torch.load(database))
    if layers is None:
        layers = list(data_emb.keys())
    num_classes = len(data_classes)
    for layer in layers:
        center = _class_centers(data_emb[layer], data_labels, num_classes)
        similarity = center @ center.T
        similarity_np = similarity.cpu().float().numpy()
        txt_path = output_dir / f"similarity_layer_{layer}.txt"
        fig_path = output_dir / f"similarity_layer_{layer}.png"
        _write_similarity_table(txt_path, data_classes, similarity_np)
        _save_similarity_heatmap(fig_path, data_classes, similarity_np, layer)
        print(f"Saved similarity matrix for layer {layer} to {txt_path} and {fig_path}")


def _class_centers(layer_embeddings, labels, num_classes: int):
    """Stack the mean embedding of each class index ``0 .. num_classes - 1``."""
    # assumes `labels` is a tensor of integer class indices aligned with the
    # first axis of `layer_embeddings` -- TODO confirm against the database writer.
    # NOTE(review): a class with no matching samples presumably yields a NaN
    # centre (mean over an empty selection) -- confirm whether that can occur.
    # Iterating indices directly avoids the quadratic `list.index` lookup and
    # keeps duplicate class names mapped to their own positions.
    return torch.stack(
        [
            torch.mean(layer_embeddings[labels == class_index], dim=0)
            for class_index in range(num_classes)
        ]
    )


def _write_similarity_table(txt_path: Path, class_names, similarity_np) -> None:
    """Write a header of class names followed by one labelled row per class."""
    with txt_path.open("w", encoding="utf-8") as f:
        f.write(" ".join(class_names) + "\n")
        for class_name, row_values in zip(class_names, similarity_np):
            row = " ".join(f"{value:.4f}" for value in row_values)
            f.write(f"{class_name} {row}\n")


def _save_similarity_heatmap(fig_path: Path, class_names, similarity_np, layer) -> None:
    """Render the similarity matrix as a labelled heat map and save it to ``fig_path``."""
    tick_positions = range(len(class_names))
    fig = plt.figure(figsize=(30, 30))
    try:
        plt.imshow(similarity_np, cmap="viridis")
        plt.colorbar()
        plt.xticks(tick_positions, class_names, rotation=45, fontsize=12)
        plt.yticks(tick_positions, class_names, fontsize=12)
        plt.title(f"Similarity Matrix (layer {layer})", fontsize=20)
        plt.savefig(fig_path, dpi=300, bbox_inches="tight")
    finally:
        # Close this specific figure even if drawing or saving fails, so
        # figures do not accumulate across layers.
        plt.close(fig)
def main(argv: Optional[Iterable[str]] = None) -> None:
    """Command-line entry point: parse the arguments and run the export.

    Args:
        argv: Argument strings handed to the parser's ``parse_args``. The
            default ``None`` is forwarded as-is, which lets argparse read the
            process command line.
    """
    namespace = build_argument_parser().parse_args(argv)
    compute_similarity(
        database=namespace.database,
        output_dir=namespace.output_dir,
        layers=namespace.layers,
    )
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
# Names exported by a star-import of this module.
__all__ = ["build_argument_parser", "compute_similarity", "main"]
|