---
license: mit
datasets:
- michael-0acf4/anitag2vec-data-reference
language:
- en
- ja
- ko
- zh
pipeline_tag: sentence-similarity
tags:
- embedding
- encoding
- similarity
- vector
- danbooru
- sakugabooru
- pixiv
- myanimelist
---

# AniTag2Vec

Generate vector embeddings from Danbooru-, Sakugabooru-, Pixiv-, and MyAnimeList (MAL)-style tags.

Training and inference examples are available on [my GitHub](https://github.com/michael-0acf4/anitag2vec).

The implementation is detailed in [this blog post](https://blog.afmichael.dev/posts/2026/set-embeddings-and-anitag2vec/).

```python
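import torch
import torch.nn.functional as F

# ModelConfig, TagBPETokenizer, AniTag2Vec and AniTag2VecRunner are assumed to come
# from the GitHub repository linked above; import them according to its layout.

TOKENIZER_PATH = "./pytorch/token_dataset_b0d065e705028cb3_vocab_size_5000_freq_3.json"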
CONFIG_PATH = "./pytorch/config_63fc21b89723d1ce_b0d065e705028cb3.json"
MODEL_PATH = "./pytorch/anitag2vec_63fc21b89723d1ce_b0d065e705028cb3_i128_e30_s157043_b256_p1871744.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = ModelConfig.load_from_file(CONFIG_PATH)
tagtok = TagBPETokenizer.load_from_file(TOKENIZER_PATH)

anitag2vec = AniTag2Vec(
    vocab_size=cfg.HYPERP_TAGTOK_VOCAB_SIZE,
    max_len_cut=cfg.HYPERP_TAGTOK_MAX_TOKEN_CLAMP,
    d_model=cfg.HYPERP_TRANSFORMER_D_MODEL,
    n_heads=cfg.HYPERP_TRANSFORMER_N_HEADS,
    n_layers=cfg.HYPERP_TRANSFORMER_N_LAYERS,
    output_emb=cfg.HYPERP_OUTPUT_EMB,
)
anitag2vec.to(device)
# map_location lets the checkpoint load on CPU-only machines as well
anitag2vec.load_state_dict(torch.load(MODEL_PATH, map_location=device))
anitag2vec.eval()
runner = AniTag2VecRunner(tagtok, anitag2vec)

# Inference
def compare(a: str, b: str):
    ax = runner.run_inference_human([a])
    bx = runner.run_inference_human([b])
    # Cosine similarity: L2-normalize each embedding, then take the dot product
    howmuch = (F.normalize(ax) @ F.normalize(bx).T).item()
    print(f"{howmuch:.2f}: '{a}' vs '{b}'")

compare("#1girl #1boy", "#1boy #1girl")
# 1.00: '#1girl #1boy' vs '#1boy #1girl'
```
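
The `compare` helper above scores two tag strings by cosine similarity: each embedding is L2-normalized and the two are then multiplied as a dot product. As a rough, dependency-free sketch of that same scoring step, the snippet below ranks a few candidates against a query. The vectors are hypothetical stand-ins rather than real AniTag2Vec outputs, and `cosine_similarity` and `rank_by_similarity` are illustrative names that are not part of the repository.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors given as lists of floats."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def rank_by_similarity(query, candidates):
    """Return (label, score) pairs sorted from most to least similar to query."""
    scored = [(label, cosine_similarity(query, vec)) for label, vec in candidates.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Hypothetical stand-in vectors, NOT real AniTag2Vec embeddings.
query = [1.0, 0.0, 1.0]
candidates = {
    "same_direction": [2.0, 0.0, 2.0],   # parallel to the query, different length
    "orthogonal": [0.0, 1.0, 0.0],       # shares no component with the query
    "opposite": [-1.0, 0.0, -1.0],       # points the other way
}

ranking = rank_by_similarity(query, candidates)
for label, score in ranking:
    print(f"{score:.2f}: {label}")
```

Because the score depends only on direction, `same_direction` ranks first even though its magnitude differs from the query. With real embeddings, the lists would presumably be replaced by the outputs of `runner.run_inference_human`, converted to plain lists or kept as tensors.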