feat: Upload anime2vec v12 model and artifacts
Browse files- README.md +40 -7
- config.json +19 -0
- le_genre.pkl +3 -0
- le_studio.pkl +3 -0
- le_type.pkl +3 -0
- model.py +70 -0
- pytorch_model.bin +3 -0
README.md
CHANGED
|
@@ -4,12 +4,45 @@ language:
|
|
| 4 |
- en
|
| 5 |
- uk
|
| 6 |
- ja
|
| 7 |
-
library_name: pytorch
|
| 8 |
tags:
|
| 9 |
- anime
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
- en
|
| 5 |
- uk
|
| 6 |
- ja
|
|
|
|
| 7 |
tags:
|
| 8 |
- anime
|
| 9 |
+
- embeddings
|
| 10 |
+
- semantic-search
|
| 11 |
+
- vector-arithmetic
|
| 12 |
+
- pytorch
|
| 13 |
+
datasets:
|
| 14 |
+
- private
|
| 15 |
+
author: Lorg0n
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
# hikka-forge-anime2vec
|
| 19 |
+
|
| 20 |
+
This repository contains `hikka-forge-anime2vec`, a sophisticated semantic vector space model for anime, created by [Lorg0n](https://huggingface.co/Lorg0n).
|
| 21 |
+
|
| 22 |
+
The model is trained to understand deep connections between titles based on multilingual textual descriptions, genres, studios, and other metadata. It supports vector arithmetic, allowing for creative queries like `"Show me something like 'Spirited Away' - 'Ghibli Style' + 'Cyberpunk'"`.
|
| 23 |
+
|
| 24 |
+
## Model Details
|
| 25 |
+
|
| 26 |
+
- **Model Version**: v12
|
| 27 |
+
- **Architecture**: A multi-input neural network with separate processing streams for text, genres, and other categorical/numerical features. It uses attention mechanisms to weigh the importance of different text fields and genres.
|
| 28 |
+
- **Training**: Trained using a combination of Triplet Loss (from explicit user recommendations), Cosine Similarity Loss for vector arithmetic examples, and a Diversity Loss to ensure a well-distributed embedding space.
|
| 29 |
+
- **Data**: Trained on a private, non-public database of anime titles.
|
| 30 |
+
|
| 31 |
+
## How to Use
|
| 32 |
+
|
| 33 |
+
*This model requires custom code for loading and inference due to its unique architecture and preprocessing steps.*
|
| 34 |
+
|
| 35 |
+
A full usage example will be provided soon. The general workflow involves:
|
| 36 |
+
1. Loading the model, config, and pickled `LabelEncoder` objects.
|
| 37 |
+
2. Preprocessing new anime data (fetching from a data source, encoding text with a SentenceTransformer, etc.).
|
| 38 |
+
3. Using the model to generate a 512-dimensional embedding.
|
| 39 |
+
4. Performing similarity search or vector arithmetic in the embedding space.
|
| 40 |
+
|
| 41 |
+
## Files in this Repository
|
| 42 |
+
|
| 43 |
+
This repository contains all files necessary for model inference:
|
| 44 |
+
|
| 45 |
+
- `pytorch_model.bin`: The trained model weights.
|
| 46 |
+
- `config.json`: Configuration file specifying model architecture and vocabulary sizes.
|
| 47 |
+
- `model.py`: The Python code defining the `AnimeEmbeddingModel` class.
|
| 48 |
+
- `le_genre.pkl`, `le_studio.pkl`, `le_type.pkl`: Pickled Scikit-learn `LabelEncoder` objects required for preprocessing new data.
|
config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architecture": "AnimeEmbeddingModel",
|
| 3 |
+
"vocab_sizes": {
|
| 4 |
+
"genre": 77,
|
| 5 |
+
"studio": 1170,
|
| 6 |
+
"type": 7
|
| 7 |
+
},
|
| 8 |
+
"embedding_dims": {
|
| 9 |
+
"genre": 128,
|
| 10 |
+
"studio": 64,
|
| 11 |
+
"type": 16,
|
| 12 |
+
"numerical": 32,
|
| 13 |
+
"text": 384
|
| 14 |
+
},
|
| 15 |
+
"text_embedding_size": 384,
|
| 16 |
+
"model_type": "hikka-forge-anime2vec",
|
| 17 |
+
"model_version": "v12",
|
| 18 |
+
"author": "Lorg0n"
|
| 19 |
+
}
|
le_genre.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b281e8d281d95b8a3681c145b23c74dd916949da2b881697c3053927d00bb8b
|
| 3 |
+
size 5476
|
le_studio.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82be27d7b828b66ed776e59f9e57093e461359eb7b6d783cecbb45697c905785
|
| 3 |
+
size 229570
|
le_type.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11535e48ef762e4f790f0013a47477c258a3a6dd0cfaa45dfc0e48557a9411ab
|
| 3 |
+
size 432
|
model.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
class TextFieldAttention(nn.Module):
    """Attention pooling over a fixed set of per-field text embeddings.

    Learns one scalar relevance score per field and collapses the field
    axis into a single weighted-average vector.
    """

    def __init__(self, num_fields: int, field_dim: int):
        super().__init__()
        # Unbiased linear projection: one attention logit per field vector.
        self.attn = nn.Linear(field_dim, 1, bias=False)
        self.num_fields = num_fields

    def forward(self, fields: torch.Tensor):
        """Pool `fields` of shape (batch, num_fields, field_dim).

        Returns:
            pooled: (batch, field_dim) attention-weighted sum of the fields.
            weights: (batch, num_fields) softmax attention weights.
        """
        raw_scores = self.attn(fields)
        # Normalize scores across the field axis.
        field_weights = F.softmax(raw_scores, dim=1)
        pooled = torch.sum(fields * field_weights, dim=1)
        return pooled, field_weights.squeeze(-1)
|
| 16 |
+
|
| 17 |
+
class GenreSelfAttention(nn.Module):
    """Masked attention pooling over a variable-length bag of genre embeddings.

    Padding positions (mask == 0) receive a large negative logit so their
    softmax weight is effectively zero.
    """

    def __init__(self, genre_dim: int):
        super().__init__()
        # Unbiased linear projection: one attention logit per genre embedding.
        self.attn_scorer = nn.Linear(genre_dim, 1, bias=False)

    def forward(self, genre_embeds: torch.Tensor, mask: torch.Tensor):
        """Pool `genre_embeds` (batch, num_genres, genre_dim) into (batch, genre_dim).

        Args:
            genre_embeds: per-genre embedding vectors.
            mask: broadcastable to the scores, 0 marks padding positions
                  (caller passes (batch, num_genres, 1)).
        """
        scores = self.attn_scorer(genre_embeds)
        # Fix: the original used in-place masked_fill_ on an autograd
        # intermediate, which can trip autograd's in-place version checks;
        # the out-of-place form is numerically identical and safe.
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=1)
        weighted_sum = (genre_embeds * weights).sum(dim=1)
        return weighted_sum
|
| 27 |
+
|
| 28 |
+
class AnimeEmbeddingModel(nn.Module):
    """Multi-input embedding model producing a 512-d unit-norm anime vector.

    Three processing streams — attention-pooled text fields, attention-pooled
    genre embeddings, and concatenated studio/type/numerical features — are
    L2-normalized, scaled by learned per-stream scalars, concatenated, and
    passed through an MLP encoder. Output is tanh-squashed then L2-normalized.
    """

    def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int], dropout_rate: float = 0.3, text_embedding_size: int = 384):
        # vocab_sizes: expects keys 'genre', 'studio', 'type'.
        # embedding_dims: expects keys 'genre', 'studio', 'type', 'numerical',
        # 'text'; their sum must equal the encoder's input width (see total_dim).
        super().__init__()
        self.embedding_dims = embedding_dims
        # padding_idx=0: index 0 is the genre padding slot (zero vector).
        self.genre_embedding = nn.Embedding(vocab_sizes['genre'], embedding_dims['genre'], padding_idx=0)
        self.studio_embedding = nn.Embedding(vocab_sizes['studio'], embedding_dims['studio'])
        self.type_embedding = nn.Embedding(vocab_sizes['type'], embedding_dims['type'])
        # Projects 6 numerical features — assumes batch['numerical'] is
        # (batch, 6); TODO confirm against the preprocessing pipeline.
        self.numerical_layer = nn.Linear(6, embedding_dims['numerical'])
        # 6 text fields, matching the six precomputed_* tensors stacked in forward().
        self.text_field_attention = TextFieldAttention(num_fields=6, field_dim=text_embedding_size)
        self.genre_attention = GenreSelfAttention(embedding_dims['genre'])
        # Encoder input width: text + genre + studio + type + numerical dims.
        # NOTE(review): this implicitly requires embedding_dims['text'] to equal
        # text_embedding_size, since the concatenated vector uses the attention
        # output of size text_embedding_size.
        total_dim = sum(embedding_dims.values())
        self.encoder = nn.Sequential(
            nn.Linear(total_dim, 1024), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(1024),
            nn.Linear(1024, 768), nn.ReLU(), nn.Dropout(dropout_rate), nn.LayerNorm(768),
            nn.Linear(768, 512),
        )
        # Learned per-stream scales applied after L2 normalization, letting the
        # model weight the relative importance of each stream.
        self.text_scale = nn.Parameter(torch.tensor(1.0))
        self.genre_scale = nn.Parameter(torch.tensor(1.0))
        self.other_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return (batch, 512) L2-normalized embeddings for a preprocessed batch.

        Expected batch keys: six 'precomputed_*' text embeddings (each
        (batch, text_embedding_size) — presumably from a SentenceTransformer,
        per the README; verify against the preprocessing code), 'genres',
        'genres_mask', 'studio', 'type', and 'numerical'.
        """
        # Stack the six per-field text embeddings into (batch, 6, field_dim).
        text_fields = torch.stack([
            batch['precomputed_ua_desc'], batch['precomputed_en_desc'],
            batch['precomputed_ua_title'], batch['precomputed_en_title'],
            batch['precomputed_original_title'], batch['precomputed_alternate_names'],
        ], dim=1)
        # Attention weights are discarded here; only the pooled vector is used.
        text_vector, _ = self.text_field_attention(text_fields)
        genre_embeds_raw = self.genre_embedding(batch['genres'])
        # Mask padding genres (0 entries in genres_mask) out of the pooling.
        genre_vector = self.genre_attention(genre_embeds_raw, batch['genres_mask'].unsqueeze(-1))
        studio_emb = self.studio_embedding(batch['studio'])
        type_emb = self.type_embedding(batch['type'])
        numerical_emb = F.relu(self.numerical_layer(batch['numerical']))
        other_vector_parts = torch.cat([studio_emb, type_emb, numerical_emb], dim=1)
        # Normalize each stream so the learned scales control their relative
        # magnitudes rather than raw activation ranges.
        text_vector_norm = F.normalize(text_vector, p=2, dim=1)
        genre_vector_norm = F.normalize(genre_vector, p=2, dim=1)
        other_vector_norm = F.normalize(other_vector_parts, p=2, dim=1)
        scaled_text = text_vector_norm * self.text_scale
        scaled_genre = genre_vector_norm * self.genre_scale
        scaled_other = other_vector_norm * self.other_scale
        combined = torch.cat([scaled_text, scaled_genre, scaled_other], dim=1)
        embedding_logits = self.encoder(combined)
        # tanh bounds components to (-1, 1) before the final unit-norm projection.
        embedding = torch.tanh(embedding_logits)
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:208dab2e050ff92356d680320ad10631afa5f5ff9c7ba605f519a944d15c3690
|
| 3 |
+
size 7647786
|