Update app.py
Browse files
app.py
CHANGED
|
@@ -311,6 +311,7 @@ def chat_bert_context(question, history):
|
|
| 311 |
|
| 312 |
#-------------------------------------Bi-BERT-Encoder------------------------------------------#
|
| 313 |
MAX_LENGTH = 128
|
|
|
|
| 314 |
# Define function for mean-pooling
|
| 315 |
def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
|
| 316 |
in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
|
|
@@ -372,12 +373,13 @@ def chat_bi_bert(question):
|
|
| 372 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
| 373 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
| 374 |
answer = df['answer'].iloc[top_indice]
|
|
|
|
| 375 |
return answer
|
| 376 |
|
| 377 |
|
| 378 |
|
| 379 |
#-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
|
| 380 |
-
|
| 381 |
|
| 382 |
#Define class for CrossEncoderBert
|
| 383 |
class CrossEncoderBert(torch.nn.Module):
|
|
@@ -418,9 +420,9 @@ def chat_cross_bert(question):
|
|
| 418 |
|
| 419 |
# Process scores for finetuned model
|
| 420 |
scores = ce_scores.cpu().numpy()
|
| 421 |
-
|
| 422 |
# print(f"{corpus[scores_ix]}")
|
| 423 |
-
return corpus[
|
| 424 |
|
| 425 |
# gradio part
|
| 426 |
def echo(message, history, model):
|
|
|
|
| 311 |
|
| 312 |
#-------------------------------------Bi-BERT-Encoder------------------------------------------#
|
| 313 |
MAX_LENGTH = 128
|
| 314 |
+
inverted_answer = dict(enumerate(df.answer.tolist()))
|
| 315 |
# Define function for mean-pooling
|
| 316 |
def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
|
| 317 |
in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
|
|
|
|
| 373 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
| 374 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
| 375 |
answer = df['answer'].iloc[top_indice]
|
| 376 |
+
answer = inverted_answer[top_indice]
|
| 377 |
return answer
|
| 378 |
|
| 379 |
|
| 380 |
|
| 381 |
#-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
|
| 382 |
+
|
| 383 |
|
| 384 |
#Define class for CrossEncoderBert
|
| 385 |
class CrossEncoderBert(torch.nn.Module):
|
|
|
|
| 420 |
|
| 421 |
# Process scores for finetuned model
|
| 422 |
scores = ce_scores.cpu().numpy()
|
| 423 |
+
ix = np.argmax(scores)
|
| 424 |
# print(f"{corpus[scores_ix]}")
|
| 425 |
+
return corpus[ix]
|
| 426 |
|
| 427 |
# gradio part
|
| 428 |
def echo(message, history, model):
|