File size: 4,985 Bytes
24aa8bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import pickle
import json
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
import numpy as np
import importlib.util
from pathlib import Path

class Anime2Vec:
    """
    A high-level wrapper to easily use the hikka-forge-anime2vec model.

    On construction it downloads all required artifacts (configuration,
    weights, model source file and label encoders) from the Hugging Face Hub,
    builds the model and loads a sentence-transformers text encoder.
    ``encode`` then turns a single anime record into an embedding vector.
    """

    # Hub repository of the sentence-transformers model used for text fields.
    _TEXT_ENCODER_ID = 'Lorg0n/hikka-forge-paraphrase-multilingual-MiniLM-L12-v2'

    # Plain-text keys of ``anime_data`` in the order they are embedded;
    # 'alternate_names' is handled separately because it is a list.
    _TEXT_KEYS = (
        'ua_description', 'en_description',
        'ua_title', 'en_title',
        'original_title',
    )

    def __init__(self, repo_id: str = "Lorg0n/hikka-forge-anime2vec", device: str = None):
        """
        Download the artifacts from ``repo_id`` and prepare the model.

        Args:
            repo_id: Hugging Face Hub repository that holds ``config.json``,
                ``pytorch_model.bin``, ``model.py`` and the ``le_*.pkl``
                label encoders.
            device: Torch device string. When falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            ImportError: If the downloaded ``model.py`` cannot be loaded as a
                Python module.
        """
        print(f"🚀 Initializing Anime2Vec from repository: {repo_id}")

        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"   - Using device: {self.device}")

        cache_dir = Path.home() / ".cache" / "hikka-forge"

        def fetch(filename):
            # Download one file from the repo root (or reuse the cached copy).
            return hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)

        # Download all necessary files from the repo root
        config_path = fetch("config.json")
        model_path = fetch("pytorch_model.bin")
        model_code_path = fetch("model.py")
        le_genre_path = fetch("le_genre.pkl")
        le_studio_path = fetch("le_studio.pkl")
        le_type_path = fetch("le_type.pkl")

        # Load configuration and encoders
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        self.le_genre = self._load_pickle(le_genre_path)
        self.le_studio = self._load_pickle(le_studio_path)
        self.le_type = self._load_pickle(le_type_path)

        # Dynamically import the model class from the downloaded model.py
        # NOTE(security): this executes code fetched from the Hub; only point
        # ``repo_id`` at repositories you trust.
        spec = importlib.util.spec_from_file_location("AnimeEmbeddingModel", model_code_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load model code from {model_code_path}")
        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)
        AnimeEmbeddingModel = model_module.AnimeEmbeddingModel

        # Initialize the model and load its weights
        self.model = AnimeEmbeddingModel(
            vocab_sizes=self.config['vocab_sizes'],
            embedding_dims=self.config['embedding_dims'],
            text_embedding_size=self.config['text_embedding_size']
        )
        # NOTE(security): torch.load unpickles the checkpoint; on torch
        # versions that support it, consider passing weights_only=True.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

        # Initialize the text encoder
        self.text_encoder = SentenceTransformer(
            self._TEXT_ENCODER_ID,
            device=self.device
        )
        print("✅ Initialization complete. Model is ready to use.")

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        # NOTE(security): pickle can run arbitrary code while loading; these
        # files come from the Hub, so only use trusted repositories.
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _label_id(encoder, value):
        """Return ``encoder``'s id for ``value``, falling back to 'UNKNOWN'."""
        label = 'UNKNOWN' if value is None else value
        try:
            return encoder.transform([label])[0]
        except ValueError:
            # transform() rejected the label; use the UNKNOWN class instead.
            return encoder.transform(['UNKNOWN'])[0]

    @torch.no_grad()
    def encode(self, anime_data: dict) -> np.ndarray:
        """
        Encodes a dictionary of anime data into a single embedding vector
        (documented as 512-dimensional).

        Every key is optional; a key that is missing or explicitly ``None``
        is replaced by a neutral default instead of raising.

        Args:
            anime_data: Record with the optional keys ``ua_description``,
                ``en_description``, ``ua_title``, ``en_title``,
                ``original_title`` (strings), ``alternate_names`` (list of
                strings), ``genres`` (list of genre labels), ``studio``,
                ``type`` (labels) and ``numerical_features`` (sequence of
                numbers, defaulting to six zeros).

        Returns:
            The model output with size-1 dimensions squeezed out, moved to
            CPU and converted to a NumPy array.
        """
        # Missing / None text becomes '' so the text encoder always gets str.
        text_fields = [anime_data.get(key) or '' for key in self._TEXT_KEYS]
        text_fields.append("; ".join(anime_data.get('alternate_names') or []))
        text_embeddings = self.text_encoder.encode(text_fields, convert_to_tensor=True)

        # Drop genres the encoder does not know; [0] stands in for "none".
        genres = anime_data.get('genres') or []
        known_genres = [g for g in genres if g in self.le_genre.classes_]
        genre_ids = self.le_genre.transform(known_genres) if known_genres else [0]

        studio_id = self._label_id(self.le_studio, anime_data.get('studio'))
        type_id = self._label_id(self.le_type, anime_data.get('type'))

        numerical_values = anime_data.get('numerical_features')
        if numerical_values is None:
            numerical_values = [0.0] * 6
        numerical = torch.tensor(numerical_values, dtype=torch.float32)

        features = {
            'precomputed_ua_desc': text_embeddings[0], 'precomputed_en_desc': text_embeddings[1],
            'precomputed_ua_title': text_embeddings[2], 'precomputed_en_title': text_embeddings[3],
            'precomputed_original_title': text_embeddings[4], 'precomputed_alternate_names': text_embeddings[5],
            'genres': torch.tensor(genre_ids, dtype=torch.long),
            'studio': torch.tensor(studio_id, dtype=torch.long),
            'type': torch.tensor(type_id, dtype=torch.long),
            'numerical': numerical
        }

        # Add a leading batch dimension of size 1 and move to the model device.
        batch = {key: tensor.unsqueeze(0).to(self.device) for key, tensor in features.items()}
        # NOTE(review): id 0 is treated as padding here; presumably the genre
        # encoder reserves index 0 for that — confirm against training code.
        batch['genres_mask'] = (batch['genres'] != 0).long()

        embedding = self.model(batch)
        return embedding.squeeze().cpu().numpy()