Soumalya Das
commited on
Update api.py
Browse files
api.py
CHANGED
|
@@ -80,8 +80,10 @@ class MovieRecommender:
|
|
| 80 |
return np.array(ids)[None,:]
|
| 81 |
|
| 82 |
def recommend(self, prompt, topk=10):
|
| 83 |
-
q_ids = self.
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
sims = cosine_similarity(query_vec, self.embeddings).flatten()
|
| 86 |
idx = sims.argsort()[::-1][:topk]
|
| 87 |
return self.movies.iloc[idx][["title","release_date","vote_average","vote_count","status"]]
|
|
|
|
| 80 |
return np.array(ids)[None,:]
|
| 81 |
|
| 82 |
def recommend(self, prompt, topk=10):
|
| 83 |
+
q_ids = self.tokenizer.texts_to_sequences([prompt])[0]
|
| 84 |
+
q_ids = [i for i in q_ids if 0 <= i < len(self.embeddings)]
|
| 85 |
+
q_ids = np.array(q_ids, dtype=np.int64)
|
| 86 |
+
query_vec = self.embeddings[q_ids].mean(axis=0, keepdims=True)
|
| 87 |
sims = cosine_similarity(query_vec, self.embeddings).flatten()
|
| 88 |
idx = sims.argsort()[::-1][:topk]
|
| 89 |
return self.movies.iloc[idx][["title","release_date","vote_average","vote_count","status"]]
|