Soumalya Das commited on
Commit
2afe572
·
verified ·
1 Parent(s): 9012408

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +4 -2
api.py CHANGED
@@ -80,8 +80,10 @@ class MovieRecommender:
80
  return np.array(ids)[None,:]
81
 
82
  def recommend(self, prompt, topk=10):
83
- q_ids = self._encode(prompt)
84
- query_vec = np.sum(self.embeddings[q_ids], axis=1)
 
 
85
  sims = cosine_similarity(query_vec, self.embeddings).flatten()
86
  idx = sims.argsort()[::-1][:topk]
87
  return self.movies.iloc[idx][["title","release_date","vote_average","vote_count","status"]]
 
80
  return np.array(ids)[None,:]
81
 
82
  def recommend(self, prompt, topk=10):
83
+ q_ids = self.tokenizer.texts_to_sequences([prompt])[0]
84
+ q_ids = [i for i in q_ids if 0 <= i < len(self.embeddings)]
85
+ q_ids = np.array(q_ids, dtype=np.int64)
86
+ query_vec = self.embeddings[q_ids].mean(axis=0, keepdims=True)
87
  sims = cosine_similarity(query_vec, self.embeddings).flatten()
88
  idx = sims.argsort()[::-1][:topk]
89
  return self.movies.iloc[idx][["title","release_date","vote_average","vote_count","status"]]