Update app.py
Browse files
app.py
CHANGED
|
@@ -368,7 +368,7 @@ model_bi_encoder.bert_model.from_pretrained("models/friends_bi_encoder")
|
|
| 368 |
# Load question embeds
|
| 369 |
question_embeds = np.load("bi_bert_question.npy")
|
| 370 |
|
| 371 |
-
def chat_bi_bert(question):
|
| 372 |
question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
| 373 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
| 374 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
|
@@ -398,7 +398,7 @@ class CrossEncoderBert(torch.nn.Module):
|
|
| 398 |
model_cross_encoder = CrossEncoderBert().to(device)
|
| 399 |
model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
|
| 400 |
|
| 401 |
-
def chat_cross_bert(question):
|
| 402 |
|
| 403 |
question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
| 404 |
cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
|
|
@@ -442,11 +442,11 @@ def echo(message, history, model):
|
|
| 442 |
return answer
|
| 443 |
|
| 444 |
elif model=="Bi-BERT-Encoder":
|
| 445 |
-
answer = chat_bi_bert(message)
|
| 446 |
return answer
|
| 447 |
|
| 448 |
elif model=="Bi+Cross-BERT-Encoder":
|
| 449 |
-
answer = chat_cross_bert(message)
|
| 450 |
return answer
|
| 451 |
|
| 452 |
|
|
|
|
| 368 |
# Load question embeds
|
| 369 |
question_embeds = np.load("bi_bert_question.npy")
|
| 370 |
|
| 371 |
+
def chat_bi_bert(question, history):
|
| 372 |
question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
| 373 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
| 374 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
|
|
|
| 398 |
model_cross_encoder = CrossEncoderBert().to(device)
|
| 399 |
model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
|
| 400 |
|
| 401 |
+
def chat_cross_bert(question, history):
|
| 402 |
|
| 403 |
question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
| 404 |
cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
|
|
|
|
| 442 |
return answer
|
| 443 |
|
| 444 |
elif model=="Bi-BERT-Encoder":
|
| 445 |
+
answer = chat_bi_bert(message, history)
|
| 446 |
return answer
|
| 447 |
|
| 448 |
elif model=="Bi+Cross-BERT-Encoder":
|
| 449 |
+
answer = chat_cross_bert(message, history)
|
| 450 |
return answer
|
| 451 |
|
| 452 |
|