Spaces:
Sleeping
Sleeping
Commit ·
905b1f8
1
Parent(s): db12d3a
fix
Browse files
app.py
CHANGED
|
@@ -123,10 +123,10 @@ def answer(system: str, context: str, question: str, user_id="demo", history="No
|
|
| 123 |
q_vec = embed(question).view(-1).cpu()
|
| 124 |
store = kb[user_id]
|
| 125 |
sims = torch.matmul(store["vecs"], q_vec) # [N]
|
| 126 |
-
if sims.
|
| 127 |
-
|
| 128 |
k = min(4, store["vecs"].size(0))
|
| 129 |
-
idxs = torch.topk(sims, k=k).indices.tolist()
|
| 130 |
context_list += [store["texts"][i] for i in idxs]
|
| 131 |
elif history == "All":
|
| 132 |
store = kb[user_id]
|
|
|
|
| 123 |
q_vec = embed(question).view(-1).cpu()
|
| 124 |
store = kb[user_id]
|
| 125 |
sims = torch.matmul(store["vecs"], q_vec) # [N]
|
| 126 |
+
if sims.shape[1] == 1:
|
| 127 |
+
sims = sims.squeeze(1)
|
| 128 |
k = min(4, store["vecs"].size(0))
|
| 129 |
+
idxs = torch.topk(sims, k=k, dim=0).indices.tolist()
|
| 130 |
context_list += [store["texts"][i] for i in idxs]
|
| 131 |
elif history == "All":
|
| 132 |
store = kb[user_id]
|