Update app.py
Browse files
app.py
CHANGED
|
@@ -32,6 +32,9 @@ nltk.download('stopwords')
|
|
| 32 |
# Take Rachel as main character
|
| 33 |
df = pd.read_csv("rachel_friends.csv") # read the database into a data frame
|
| 34 |
|
|
|
|
|
|
|
|
|
|
| 35 |
# Define function for text normalization
|
| 36 |
def text_normalization(text):
|
| 37 |
text = str(text).lower() # convert to all lower letters
|
|
@@ -126,7 +129,7 @@ def chat_tfidf_context(question, history):
|
|
| 126 |
answer = df['answer'].loc[index_value]
|
| 127 |
|
| 128 |
return answer
|
| 129 |
-
|
| 130 |
punkt = [p for p in punctuation] + ["`", "``" ,"''", "'"]
|
| 131 |
|
| 132 |
def tokenize(sent: str) -> str:
|
|
@@ -201,7 +204,7 @@ def chat_word2vec_context(question, history):
|
|
| 201 |
|
| 202 |
return answer
|
| 203 |
|
| 204 |
-
|
| 205 |
|
| 206 |
# Let's try bert model by elastic and with e5
|
| 207 |
model_name = "distilbert/distilbert-base-uncased"
|
|
@@ -296,23 +299,144 @@ def chat_bert_context(question, history):
|
|
| 296 |
answer = df['answer'].loc[relevant_indice]
|
| 297 |
|
| 298 |
return answer
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
# gradio part
|
| 301 |
def echo(message, history, model):
|
| 302 |
|
| 303 |
-
|
| 304 |
# answer = chat_tfidf(message)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
# answer = chat_word2vec(message)
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
|
| 317 |
|
| 318 |
|
|
@@ -321,7 +445,7 @@ title = "Chatbot who speaks like Rachel from Friends"
|
|
| 321 |
description = "You have a good opportunity to have a dialog with friend's actor - Rachel Green"
|
| 322 |
|
| 323 |
# model = gr.CheckboxGroup(["TF-IDF", "W2V", "BERT", "BI-Encoder", "Cross-Encoder"], label="Model", info="What model do you want to use?", value="TF-IDF")
|
| 324 |
-
model = gr.Dropdown(["TF-IDF", "W2V", "BERT", "
|
| 325 |
|
| 326 |
with gr.Blocks() as demo:
|
| 327 |
|
|
|
|
| 32 |
# Take Rachel as main character
|
| 33 |
df = pd.read_csv("rachel_friends.csv") # read the database into a data frame
|
| 34 |
|
| 35 |
+
|
| 36 |
+
#-------------------------------------TF-IDF------------------------------------------#
|
| 37 |
+
|
| 38 |
# Define function for text normalization
|
| 39 |
def text_normalization(text):
|
| 40 |
text = str(text).lower() # convert to all lower letters
|
|
|
|
| 129 |
answer = df['answer'].loc[index_value]
|
| 130 |
|
| 131 |
return answer
|
| 132 |
+
#-------------------------------------W2V------------------------------------------#
|
| 133 |
punkt = [p for p in punctuation] + ["`", "``" ,"''", "'"]
|
| 134 |
|
| 135 |
def tokenize(sent: str) -> str:
|
|
|
|
| 204 |
|
| 205 |
return answer
|
| 206 |
|
| 207 |
+
#-------------------------------------BERT------------------------------------------#
|
| 208 |
|
| 209 |
# Let's try bert model by elastic and with e5
|
| 210 |
model_name = "distilbert/distilbert-base-uncased"
|
|
|
|
| 299 |
answer = df['answer'].loc[relevant_indice]
|
| 300 |
|
| 301 |
return answer
|
| 302 |
+
|
| 303 |
+
#-------------------------------------Bi-BERT-Encoder------------------------------------------#
|
| 304 |
+
|
| 305 |
+
# Define function for mean-pooling
|
| 306 |
+
def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
|
| 307 |
+
in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
|
| 308 |
+
pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
|
| 309 |
+
return pool
|
| 310 |
+
|
| 311 |
+
# Define function for tokenization of the sentence and encoding it
|
| 312 |
+
def encode(input_texts: list[str], tokenizer: AutoTokenizer, model: AutoModel, device: str = "cpu"
|
| 313 |
+
) -> torch.tensor:
|
| 314 |
+
|
| 315 |
+
model.eval()
|
| 316 |
+
tokenized_texts = tokenizer(input_texts, max_length=128,
|
| 317 |
+
padding='max_length', truncation=True, return_tensors="pt")
|
| 318 |
+
token_embeds = model(tokenized_texts["input_ids"].to(device),
|
| 319 |
+
tokenized_texts["attention_mask"].to(device)).last_hidden_state
|
| 320 |
+
pooled_embeds = mean_pool(token_embeds, tokenized_texts["attention_mask"].to(device))
|
| 321 |
+
return pooled_embeds
|
| 322 |
+
|
| 323 |
+
# Define architecture for bi-bert-encoder
|
| 324 |
+
class Sbert(torch.nn.Module):
|
| 325 |
+
def __init__(self, max_length: int = 128):
|
| 326 |
+
super().__init__()
|
| 327 |
+
self.max_length = max_length
|
| 328 |
+
self.bert_model = AutoModel.from_pretrained('distilbert-base-uncased')
|
| 329 |
+
self.bert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
|
| 330 |
+
self.linear = torch.nn.Linear(self.bert_model.config.hidden_size * 3, 1)
|
| 331 |
+
# self.sigmoid = torch.nn.Sigmoid()
|
| 332 |
+
|
| 333 |
+
def forward(self, data: datasets.arrow_dataset.Dataset) -> torch.tensor:
|
| 334 |
+
question_input_ids = data["question_input_ids"].to(device)
|
| 335 |
+
question_attention_mask = data["question_attention_mask"].to(device)
|
| 336 |
+
answer_input_ids = data["answer_input_ids"].to(device)
|
| 337 |
+
answer_attention_mask = data["answer_attention_mask"].to(device)
|
| 338 |
+
|
| 339 |
+
out_question = self.bert_model(question_input_ids, question_attention_mask)
|
| 340 |
+
out_answer = self.bert_model(answer_input_ids, answer_attention_mask)
|
| 341 |
+
question_embeds = out_question.last_hidden_state
|
| 342 |
+
answer_embeds = out_answer.last_hidden_state
|
| 343 |
+
|
| 344 |
+
pooled_question_embeds = mean_pool(question_embeds, question_attention_mask)
|
| 345 |
+
pooled_answer_embeds = mean_pool(answer_embeds, answer_attention_mask)
|
| 346 |
+
|
| 347 |
+
embeds = torch.cat([pooled_question_embeds, pooled_answer_embeds,
|
| 348 |
+
torch.abs(pooled_question_embeds - pooled_answer_embeds)],
|
| 349 |
+
dim=-1)
|
| 350 |
+
# return self.sigmoid(self.linear(embeds))
|
| 351 |
+
return self.linear(embeds)
|
| 352 |
+
|
| 353 |
+
# Initialize the model
|
| 354 |
+
model_bi_encoder = Sbert().to(device)
|
| 355 |
+
# Load weights from training step
|
| 356 |
+
model_bi_encoder.bert_model.from_pretrained("models/friends_bi_encoder")
|
| 357 |
+
|
| 358 |
+
# Load question embeds
|
| 359 |
+
question_embeds = np.load("bi_bert_question.npy")
|
| 360 |
+
|
| 361 |
+
def chat_bi_bert(question):
|
| 362 |
+
question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
| 363 |
+
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
| 364 |
+
top_indice = np.argmax(cosine_similarities, axis=0)
|
| 365 |
+
answer = df['answer'].iloc[top_indice]
|
| 366 |
+
return answer
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
#-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
|
| 371 |
+
inverted_question = dict(enumerate(df.answer.tolist()))
|
| 372 |
+
|
| 373 |
+
#Define class for CrossEncoderBert
|
| 374 |
+
class CrossEncoderBert(torch.nn.Module):
|
| 375 |
+
def __init__(self, max_length: int = 128):
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.max_length = max_length
|
| 378 |
+
self.bert_model = AutoModel.from_pretrained('distilbert-base-uncased')
|
| 379 |
+
self.bert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
|
| 380 |
+
self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
|
| 381 |
+
|
| 382 |
+
def forward(self, input_ids, attention_mask):
|
| 383 |
+
outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 384 |
+
pooled_output = outputs.last_hidden_state[:, 0] # Use the CLS token's output
|
| 385 |
+
return self.linear(pooled_output)
|
| 386 |
+
|
| 387 |
+
model_cross_encoder = CrossEncoderBert().to(device)
|
| 388 |
+
model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
|
| 389 |
+
|
| 390 |
+
def chat_cross_bert(question):
|
| 391 |
+
|
| 392 |
+
question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
| 393 |
+
cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
|
| 394 |
+
topk_indices = np.argsort(cosine_similarities, axis=0)[::-1][:5]
|
| 395 |
+
topk_indices=topk_indices.tolist()
|
| 396 |
+
corpus = [inverted_question[ind] for ind in topk_indices]
|
| 397 |
+
|
| 398 |
+
queries = [question] * len(corpus)
|
| 399 |
+
|
| 400 |
+
tokenized_texts = model_cross_encoder.bert_tokenizer(
|
| 401 |
+
queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
|
| 402 |
+
).to(device)
|
| 403 |
+
|
| 404 |
+
# Finetuned CrossEncoder model scoring
|
| 405 |
+
with torch.no_grad():
|
| 406 |
+
ce_scores = model_cross_encoder(tokenized_texts['input_ids'], tokenized_texts['attention_mask']).squeeze(-1)
|
| 407 |
+
ce_scores = torch.sigmoid(ce_scores) # Apply sigmoid if needed
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# Process scores for finetuned model
|
| 411 |
+
scores = ce_scores.cpu().numpy()
|
| 412 |
+
scores_ix = np.argmax(scores)
|
| 413 |
+
# print(f"{corpus[scores_ix]}")
|
| 414 |
+
return corpus[scores_ix]
|
| 415 |
+
|
| 416 |
# gradio part
|
| 417 |
def echo(message, history, model):
|
| 418 |
|
| 419 |
+
if model=="TF-IDF":
|
| 420 |
# answer = chat_tfidf(message)
|
| 421 |
+
answer = chat_tfidf_context(message, history)
|
| 422 |
+
return answer
|
| 423 |
+
|
| 424 |
+
elif model=="W2V":
|
| 425 |
# answer = chat_word2vec(message)
|
| 426 |
+
answer = chat_word2vec_context(message, history)
|
| 427 |
+
return answer
|
| 428 |
+
|
| 429 |
+
elif model=="BERT":
|
| 430 |
+
answer = chat_bert_context(message, history)
|
| 431 |
+
return answer
|
| 432 |
+
|
| 433 |
+
elif model=="Bi-BERT-Encoder":
|
| 434 |
+
answer = chat_bi_bert(message)
|
| 435 |
+
return answer
|
| 436 |
|
| 437 |
+
elif model=="Bi+Cross-BERT-Encoder":
|
| 438 |
+
answer = chat_cross_bert(message)
|
| 439 |
+
return answer
|
| 440 |
|
| 441 |
|
| 442 |
|
|
|
|
| 445 |
description = "You have a good opportunity to have a dialog with friend's actor - Rachel Green"
|
| 446 |
|
| 447 |
# model = gr.CheckboxGroup(["TF-IDF", "W2V", "BERT", "BI-Encoder", "Cross-Encoder"], label="Model", info="What model do you want to use?", value="TF-IDF")
|
| 448 |
+
model = gr.Dropdown(["TF-IDF", "W2V", "BERT", "Bi-BERT-Encoder", "Bi+Cross-BERT-Encoder"], label="Retrieval model", info="What model do you want to use?", value="TF-IDF")
|
| 449 |
|
| 450 |
with gr.Blocks() as demo:
|
| 451 |
|