File size: 3,641 Bytes
b0986f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2afe572
 
 
 
b0986f4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import pandas as pd
import json
import pickle
import io
from sklearn.metrics.pairwise import cosine_similarity

class MovieRecommender:
    """Content-based movie recommender backed by precomputed token embeddings.

    Expects the following files under ``model_path``:
      * ``embeddings.npy``       -- (vocab_size, dim) float array, one row per token id
      * ``tokenizer_vocab.json`` -- {token: id} mapping (preferred)
      * ``tokenizer.pkl``        -- legacy pickled tokenizer (fallback source of the vocab)
      * ``movies.json``          -- movie metadata table (row i pairs with embedding row i's
                                    similarity rank; schema must include title/release_date/
                                    vote_average/vote_count/status)
    """

    def __init__(self, model_path="."):
        self.embeddings = np.load(f"{model_path}/embeddings.npy")
        # Sanitize: downstream mean/argsort must not see NaN/Inf rows.
        self.embeddings = np.nan_to_num(self.embeddings)

        # Prefer the plain-JSON vocab: safe, portable, no pickle involved.
        try:
            with open(f"{model_path}/tokenizer_vocab.json", "r") as f:
                self.tokenizer = json.load(f)
        except FileNotFoundError:
            # Legacy fallback: pull the vocab dict out of the pickle, then
            # cache it as JSON so the pickle path is only ever taken once.
            self.tokenizer = self._extract_vocab_from_pickle(f"{model_path}/tokenizer.pkl")
            with open(f"{model_path}/tokenizer_vocab.json", "w") as f:
                json.dump(self.tokenizer, f)

        self.movies = pd.read_json(f"{model_path}/movies.json")

    def _extract_vocab_from_pickle(self, filepath):
        """Best-effort extraction of a {token: id} dict from a pickle file.

        SECURITY NOTE: unpickling untrusted data can execute arbitrary code.
        The restricted Unpickler below maps every global reference to ``dict``,
        so the payload cannot resolve arbitrary classes/functions; only plain
        containers survive loading.  Returns {} (with a warning) when nothing
        dict-like can be recovered.

        BUG FIX: the original assigned ``unpickler.find_class = lambda ...`` on
        a ``pickle.Unpickler`` instance.  CPython's C implementation rejects
        that attribute assignment, and the bare ``except`` swallowed the error,
        so the restriction silently never applied.  Subclassing is the
        documented way to override ``find_class``.
        """
        with open(filepath, "rb") as f:
            pickle_data = f.read()

        class _RestrictedUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                # Refuse real global lookups; mirrors the original intent of
                # "disable loading of classes that don't exist".
                return dict

        try:
            result = _RestrictedUnpickler(io.BytesIO(pickle_data)).load()
        except Exception:
            # Corrupt data, or a payload that genuinely needs its globals.
            result = None

        if isinstance(result, dict) and result:
            return result

        # Some tokenizers pickle an object whose *state* holds the vocab; with
        # find_class -> dict the reduce protocol often yields nested plain
        # containers.  Scan them for the first str-keyed mapping.  (Replaces
        # the original dead pickletools scan, which never returned anything.)
        if result is not None:
            for candidate in self._iter_dicts(result):
                if candidate and all(isinstance(k, str) for k in candidate):
                    return candidate

        # Final fallback: degrade gracefully rather than crash at startup.
        print("Warning: Could not extract vocabulary from pickle. Using empty tokenizer.")
        print("Recommendation quality will be limited.")
        return {}

    @staticmethod
    def _iter_dicts(obj, _depth=0):
        """Yield every dict reachable inside nested lists/tuples/dicts."""
        if _depth > 8:  # defensive cap against adversarially deep payloads
            return
        if isinstance(obj, dict):
            yield obj
            for value in obj.values():
                yield from MovieRecommender._iter_dicts(value, _depth + 1)
        elif isinstance(obj, (list, tuple)):
            for value in obj:
                yield from MovieRecommender._iter_dicts(value, _depth + 1)

    def _encode(self, prompt):
        """Map a text prompt to a (1, n_tokens) array of embedding-row ids."""
        tokens = prompt.lower().split()[:32]  # cap prompt length at 32 tokens
        ids = [self.tokenizer.get(t, 0) for t in tokens]  # unknown token -> id 0
        # Clamp out-of-range ids (stale vocab vs. embedding table) to 0.
        ids = [i if i < len(self.embeddings) else 0 for i in ids]
        return np.array(ids)[None, :]

    def recommend(self, prompt, topk=10):
        """Return the ``topk`` movies most similar to ``prompt``.

        BUG FIX: the original called ``self.tokenizer.texts_to_sequences``,
        but ``self.tokenizer`` is a plain {token: id} dict (see __init__), so
        every call raised AttributeError.  Use the dict-based ``_encode``
        (previously dead code) and guard against prompts with no tokens,
        whose mean over an empty slice would be NaN.
        """
        q_ids = self._encode(prompt).ravel()
        q_ids = q_ids[(q_ids >= 0) & (q_ids < len(self.embeddings))]
        if q_ids.size == 0:
            # Empty/unrecognized prompt: fall back to the id-0 embedding
            # instead of producing a NaN query vector.
            q_ids = np.array([0])
        query_vec = self.embeddings[q_ids.astype(np.int64)].mean(axis=0)

        # Cosine similarity in plain NumPy (drops the sklearn dependency for a
        # one-liner).  Zero-norm rows get similarity 0, matching sklearn's
        # convention of leaving zero vectors un-normalized.
        row_norms = np.linalg.norm(self.embeddings, axis=1)
        q_norm = np.linalg.norm(query_vec)
        denom = np.where(row_norms == 0.0, 1.0, row_norms) * (q_norm if q_norm else 1.0)
        sims = (self.embeddings @ query_vec) / denom

        idx = np.argsort(sims)[::-1][:topk]
        return self.movies.iloc[idx][["title", "release_date", "vote_average", "vote_count", "status"]]