Cyantist8208 commited on
Commit
905b1f8
·
1 Parent(s): db12d3a
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -123,10 +123,10 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
123
  q_vec = embed(question).view(-1).cpu()
124
  store = kb[user_id]
125
  sims = torch.matmul(store["vecs"], q_vec) # [N]
126
- if sims.numel() == 0:
127
- return "Knowledge base is empty or corrupted."
128
  k = min(4, store["vecs"].size(0))
129
- idxs = torch.topk(sims, k=k).indices.tolist()
130
  context_list += [store["texts"][i] for i in idxs]
131
  elif history == "All":
132
  store = kb[user_id]
 
123
  q_vec = embed(question).view(-1).cpu()
124
  store = kb[user_id]
125
  sims = torch.matmul(store["vecs"], q_vec) # [N]
126
+ if sims.shape[1] == 1:
127
+ sims = sims.squeeze(1)
128
  k = min(4, store["vecs"].size(0))
129
+ idxs = torch.topk(sims, k=k, dim=0).indices.tolist()
130
  context_list += [store["texts"][i] for i in idxs]
131
  elif history == "All":
132
  store = kb[user_id]