StKirill committed on
Commit
844cfc1
·
verified ·
1 Parent(s): cb9b058

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -14
app.py CHANGED
@@ -32,6 +32,9 @@ nltk.download('stopwords')
32
  # Take Rachel as main character
33
  df = pd.read_csv("rachel_friends.csv") # read the database into a data frame
34
 
 
 
 
35
  # Define function for text normalization
36
  def text_normalization(text):
37
  text = str(text).lower() # convert to all lower letters
@@ -126,7 +129,7 @@ def chat_tfidf_context(question, history):
126
  answer = df['answer'].loc[index_value]
127
 
128
  return answer
129
- #------------------------------------------------------------------------------------------------#
130
  punkt = [p for p in punctuation] + ["`", "``" ,"''", "'"]
131
 
132
  def tokenize(sent: str) -> str:
@@ -201,7 +204,7 @@ def chat_word2vec_context(question, history):
201
 
202
  return answer
203
 
204
- #------------------------------------------------------------------------------------------------#
205
 
206
  # Let's try bert model by elastic and with e5
207
  model_name = "distilbert/distilbert-base-uncased"
@@ -296,23 +299,144 @@ def chat_bert_context(question, history):
296
  answer = df['answer'].loc[relevant_indice]
297
 
298
  return answer
299
- #------------------------------------------------------------------------------------------------#
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  # gradio part
301
  def echo(message, history, model):
302
 
303
- if model=="TF-IDF":
304
  # answer = chat_tfidf(message)
305
- answer = chat_tfidf_context(message, history)
306
- return answer
307
-
308
- elif model=="W2V":
309
  # answer = chat_word2vec(message)
310
- answer = chat_word2vec_context(message, history)
311
- return answer
 
 
 
 
 
 
 
 
312
 
313
- elif model=="BERT":
314
- answer = chat_bert_context(message, history)
315
- return answer
316
 
317
 
318
 
@@ -321,7 +445,7 @@ title = "Chatbot who speaks like Rachel from Friends"
321
  description = "You have a good opportunity to have a dialog with friend's actor - Rachel Green"
322
 
323
  # model = gr.CheckboxGroup(["TF-IDF", "W2V", "BERT", "BI-Encoder", "Cross-Encoder"], label="Model", info="What model do you want to use?", value="TF-IDF")
324
- model = gr.Dropdown(["TF-IDF", "W2V", "BERT", "BI-Encoder", "Cross-Encoder"], label="Retrieval model", info="What model do you want to use?", value="TF-IDF")
325
 
326
  with gr.Blocks() as demo:
327
 
 
32
  # Take Rachel as main character
33
  df = pd.read_csv("rachel_friends.csv") # read the database into a data frame
34
 
35
+
36
+ #-------------------------------------TF-IDF------------------------------------------#
37
+
38
  # Define function for text normalization
39
  def text_normalization(text):
40
  text = str(text).lower() # convert to all lower letters
 
129
  answer = df['answer'].loc[index_value]
130
 
131
  return answer
132
+ #-------------------------------------W2V------------------------------------------#
133
  punkt = [p for p in punctuation] + ["`", "``" ,"''", "'"]
134
 
135
  def tokenize(sent: str) -> str:
 
204
 
205
  return answer
206
 
207
+ #-------------------------------------BERT------------------------------------------#
208
 
209
  # Let's try bert model by elastic and with e5
210
  model_name = "distilbert/distilbert-base-uncased"
 
299
  answer = df['answer'].loc[relevant_indice]
300
 
301
  return answer
302
+
303
+ #-------------------------------------Bi-BERT-Encoder------------------------------------------#
304
+
305
# Mean-pooling over real (non-padding) token positions.
def mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
    """Average token embeddings, ignoring padded positions.

    token_embeds is (batch, seq, hidden), attention_mask is (batch, seq);
    returns (batch, hidden).
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    summed = (token_embeds * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
    return summed / counts
310
+
311
# Tokenize sentences and return their mean-pooled embeddings.
def encode(input_texts: list[str], tokenizer: AutoTokenizer, model: AutoModel, device: str = "cpu"
           ) -> torch.tensor:
    """Embed *input_texts* with *model* and mean-pool over real tokens.

    Returns a (batch, hidden) tensor of sentence embeddings.
    """
    model.eval()
    tokenized_texts = tokenizer(input_texts, max_length=128,
                                padding='max_length', truncation=True, return_tensors="pt")
    attention_mask = tokenized_texts["attention_mask"].to(device)
    # Inference only: disable autograd so no graph is built (the original
    # tracked gradients needlessly; callers detach the result anyway).
    with torch.no_grad():
        token_embeds = model(tokenized_texts["input_ids"].to(device),
                             attention_mask).last_hidden_state
        pooled_embeds = mean_pool(token_embeds, attention_mask)
    return pooled_embeds
322
+
323
# Bi-encoder (SBERT-style) architecture.
class Sbert(torch.nn.Module):
    """Embeds question and answer independently with a shared DistilBERT
    backbone and scores the pair from the features [q, a, |q - a|]."""

    def __init__(self, max_length: int = 128):
        super().__init__()
        self.max_length = max_length
        self.bert_model = AutoModel.from_pretrained('distilbert-base-uncased')
        self.bert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
        # Scoring head over the concatenated [q, a, |q - a|] vector.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size * 3, 1)

    def forward(self, data: datasets.arrow_dataset.Dataset) -> torch.tensor:
        # Pull the pre-tokenized question/answer tensors onto the device.
        q_ids = data["question_input_ids"].to(device)
        q_mask = data["question_attention_mask"].to(device)
        a_ids = data["answer_input_ids"].to(device)
        a_mask = data["answer_attention_mask"].to(device)

        # Encode both sides with the shared backbone, then mean-pool.
        q_pooled = mean_pool(self.bert_model(q_ids, q_mask).last_hidden_state, q_mask)
        a_pooled = mean_pool(self.bert_model(a_ids, a_mask).last_hidden_state, a_mask)

        # Classic SBERT pair features: [q, a, |q - a|].
        features = torch.cat(
            [q_pooled, a_pooled, torch.abs(q_pooled - a_pooled)], dim=-1)
        return self.linear(features)
352
+
353
# Initialize the bi-encoder and load the fine-tuned weights.
model_bi_encoder = Sbert().to(device)
# BUG FIX: `from_pretrained` is a classmethod that RETURNS a newly loaded
# model — the original call discarded the result, so the fine-tuned weights
# were never used. Assign it back (and move it onto the device).
model_bi_encoder.bert_model = model_bi_encoder.bert_model.from_pretrained(
    "models/friends_bi_encoder").to(device)

# Precomputed bi-encoder embeddings of every corpus question.
question_embeds = np.load("bi_bert_question.npy")
360
+
361
def chat_bi_bert(question):
    """Return the stored answer whose corpus question is most similar to
    *question* under the fine-tuned bi-encoder (cosine similarity)."""
    question = encode(question, model_bi_encoder.bert_tokenizer,
                      model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
    cosine_similarities = cosine_similarity([question], question_embeds).flatten()
    top_indice = np.argmax(cosine_similarities, axis=0)
    # BUG FIX: original indexed with the undefined name `ind` (NameError on
    # every call); use the argmax index computed above.
    answer = df['answer'].iloc[top_indice]
    return answer
367
+
368
+
369
+
370
#-------------------------------------Bi+Cross-BERT-Encoder------------------------------------------#
# Map corpus row index -> answer text, used by the cross-encoder retriever.
# NOTE(review): `df_base` is not defined in this hunk — presumably the full
# corpus dataframe; confirm it exists at module level (or should be `df`).
inverted_question = {idx: ans for idx, ans in enumerate(df_base.answer.tolist())}
372
+
373
# Cross-encoder: scores a (query, candidate) pair jointly.
class CrossEncoderBert(torch.nn.Module):
    """DistilBERT cross-encoder that reads the concatenated pair and scores
    it from the [CLS] position.

    NOTE(review): the `MAX_LENGTH` default must be defined at module level —
    it is not visible in this hunk; confirm.
    """

    def __init__(self, max_length: int = MAX_LENGTH):
        super().__init__()
        self.max_length = max_length
        self.bert_model = AutoModel.from_pretrained('distilbert-base-uncased')
        self.bert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
        # Single-logit relevance head on top of the backbone.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        cls_state = outputs.last_hidden_state[:, 0]  # Use the CLS token's output
        return self.linear(cls_state)
386
+
387
# Initialize the cross-encoder and load the fine-tuned weights.
model_cross_encoder = CrossEncoderBert().to(device)
# BUG FIX: `from_pretrained` is a classmethod that RETURNS a newly loaded
# model — the original call discarded the result, so the fine-tuned weights
# were never used. Assign it back (and move it onto the device).
model_cross_encoder.bert_model = model_cross_encoder.bert_model.from_pretrained(
    "models/friends_cross_encoder").to(device)
389
+
390
def chat_cross_bert(query, tokenizer=None, finetuned_ce=None, base_bert=None):
    """Answer *query* by bi-encoder recall + cross-encoder re-ranking.

    Shortlists the 5 corpus questions most similar to *query* under the
    bi-encoder, then re-scores each (query, answer) pair with the fine-tuned
    cross-encoder and returns the best-scoring answer.

    BUG FIX: the original signature `(tokenizer, finetuned_ce, base_bert,
    query)` could not be satisfied by the actual call site
    `chat_cross_bert(message)` in `echo`, and the body referenced the
    undefined names `question` and `inverted_answer`. The query now comes
    first; the model arguments default to the module-level instances.
    """
    if tokenizer is None:
        tokenizer = model_cross_encoder.bert_tokenizer

    # Bi-encoder recall: embed the query and take the top-5 corpus matches.
    query_embed = encode(query, model_bi_encoder.bert_tokenizer,
                         model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
    cosine_similarities = cosine_similarity([query_embed], question_embeds).flatten()
    topk_indices = np.argsort(cosine_similarities, axis=0)[::-1][:5]
    topk_indices = topk_indices.tolist()
    # `inverted_question` (defined above) maps corpus index -> answer text.
    corpus = [inverted_question[ind] for ind in topk_indices]

    queries = [query] * len(corpus)

    # NOTE(review): MAX_LENGTH assumed defined at module level — confirm.
    tokenized_texts = tokenizer(
        queries, corpus, max_length=MAX_LENGTH, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Finetuned CrossEncoder model scoring
    with torch.no_grad():
        ce_scores = model_cross_encoder(tokenized_texts['input_ids'],
                                        tokenized_texts['attention_mask']).squeeze(-1)
        ce_scores = torch.sigmoid(ce_scores)  # map logits to (0, 1)

    # Pick the candidate the cross-encoder ranks highest.
    scores = ce_scores.cpu().numpy()
    scores_ix = np.argmax(scores)
    return corpus[scores_ix]
415
+
416
# gradio part
def echo(message, history, model):
    """Route the user *message* to the retrieval backend selected in the
    Gradio dropdown and return the chatbot's answer.

    Returns None for an unrecognized model name (no matching branch).
    """
    if model == "TF-IDF":
        # answer = chat_tfidf(message)
        return chat_tfidf_context(message, history)

    if model == "W2V":
        # answer = chat_word2vec(message)
        return chat_word2vec_context(message, history)

    if model == "BERT":
        return chat_bert_context(message, history)

    if model == "Bi-BERT-Encoder":
        return chat_bi_bert(message)

    if model == "Bi+Cross-BERT-Encoder":
        return chat_cross_bert(message)
440
 
441
 
442
 
 
445
  description = "You have a good opportunity to have a dialog with friend's actor - Rachel Green"
446
 
447
  # model = gr.CheckboxGroup(["TF-IDF", "W2V", "BERT", "BI-Encoder", "Cross-Encoder"], label="Model", info="What model do you want to use?", value="TF-IDF")
448
+ model = gr.Dropdown(["TF-IDF", "W2V", "BERT", "Bi-BERT-Encoder", "Bi+Cross-BERT-Encoder"], label="Retrieval model", info="What model do you want to use?", value="TF-IDF")
449
 
450
  with gr.Blocks() as demo:
451