File size: 3,028 Bytes
4d939fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Compute similarity matrices from embedding databases."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable, Optional

import matplotlib.pyplot as plt
import torch


def gen_data(dict_data):
    embeddings = dict_data["embeddings"]
    labels = dict_data["labels"]
    ids = dict_data["ids"]
    classes = dict_data["classes"]
    return embeddings, labels, ids, classes


def build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Generate class similarity matrices for DETree.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--database", type=Path, required=True, help="Path to the embedding database (.pt).")
    parser.add_argument("--output-dir", type=Path, required=True, help="Directory to store the similarity outputs.")
    parser.add_argument("--layers", type=int, nargs="*", default=None, help="Specific layers to export. Defaults to all.")
    return parser


def compute_similarity(database: Path, output_dir: Path, layers: Optional[Iterable[int]]) -> None:
    output_dir.mkdir(parents=True, exist_ok=True)
    data_emb, data_labels, data_ids, data_classes = gen_data(torch.load(database))

    if layers is None:
        layers = list(data_emb.keys())

    for layer in layers:
        center = []
        for item in data_classes:
            index = data_classes.index(item)
            now_emb = data_emb[layer][data_labels == index]
            center.append(torch.mean(now_emb, dim=0))
        center = torch.stack(center)
        similarity = center @ center.T
        similarity_np = similarity.cpu().float().numpy()

        txt_path = output_dir / f"similarity_layer_{layer}.txt"
        with txt_path.open("w", encoding="utf-8") as f:
            f.write(" ".join(data_classes) + "\n")
            for i, class_name in enumerate(data_classes):
                row = " ".join(f"{similarity_np[i, j]:.4f}" for j in range(len(data_classes)))
                f.write(f"{class_name} {row}\n")

        plt.figure(figsize=(30, 30))
        plt.imshow(similarity_np, cmap="viridis")
        plt.colorbar()
        plt.xticks(range(len(data_classes)), data_classes, rotation=45, fontsize=12)
        plt.yticks(range(len(data_classes)), data_classes, fontsize=12)
        plt.title(f"Similarity Matrix (layer {layer})", fontsize=20)
        fig_path = output_dir / f"similarity_layer_{layer}.png"
        plt.savefig(fig_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"Saved similarity matrix for layer {layer} to {txt_path} and {fig_path}")


def main(argv: Optional[Iterable[str]] = None) -> None:
    parser = build_argument_parser()
    args = parser.parse_args(argv)
    compute_similarity(args.database, args.output_dir, args.layers)


if __name__ == "__main__":
    main()

__all__ = ["build_argument_parser", "compute_similarity", "main"]