wb-droid committed on
Commit
93ecd47
·
1 Parent(s): 794d1e1

first commit

Browse files
Files changed (4) hide show
  1. app.py +80 -0
  2. myTextEmbedding.py +98 -0
  3. myTextEmbeddingStudent.pt +3 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from myTextEmbedding import *
2
+ import gradio as gr
3
+
4
def generate_chunk_emb(m, chunk_data):
    """Embed every chunk in *chunk_data* with model *m* on CPU, gradient-free."""
    with torch.no_grad():
        return m(chunk_data, device="cpu")
8
+
9
def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return the top-*topk* chunks from *chunk_data* most similar to query *s*.

    s          -- query string
    chunk_data -- list of text chunks (parallel to the rows of chunk_emb)
    chunk_emb  -- precomputed chunk embeddings, one row per chunk
    m          -- embedding model exposing __call__(sentences, device=...) and cos_score
    topk       -- maximum number of chunks to return
    """
    question = [s]
    with torch.no_grad():
        # Broadcast the single question embedding against every chunk row,
        # then score row-wise cosine similarity.
        scores = m.cos_score(m(question, device="cpu").expand(chunk_emb.shape), chunk_emb)
    # BUG FIX: clamp k so torch.topk never asks for more results than exist
    # (the original crashed when the corpus had fewer than `topk` chunks).
    # Debug prints from the original have been removed.
    k = min(topk, scores.numel())
    _, idxs = torch.topk(scores, k)
    return [chunk_data[idx] for idx in idxs.flatten().tolist() if idx < len(chunk_data)]
21
+
22
+ # create the student training model
23
# create the student training model
class TrainStudent(nn.Module):
    """Distillation wrapper: measures a student model against a teacher via MSE."""

    def __init__(self, student_model):
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        # Embed the same sentences with both models and penalise the gap.
        student_out = self.student_model(s1)
        teacher_out = teacher_model(s1)
        return (student_out - teacher_out).pow(2).mean()
33
+
34
# Build the default corpus from three seed Wikipedia concepts (performs network
# requests at import time).
chunk_data = generate_chunk_data(["AI","moon","brain"])
# Load the distilled checkpoint and unwrap the embedding model stored on its
# .student_model attribute.
# NOTE(review): torch.load of a pickled module executes code from the
# checkpoint file — acceptable only because the file ships with this app.
student_model=torch.load("myTextEmbeddingStudent.pt",map_location='cpu').student_model
# create the embedding vector database
chunk_emb = generate_chunk_emb(student_model, chunk_data)
38
+
39
#new_chunk_data = []
#new_chunk_emb = tensor([])
# Stub handler for the (currently commented-out) "Add concepts" button: echoes
# the textbox content back unchanged. Not wired to any event below.
def addNewConcepts(user_concepts):
    # No-op passthrough.
    return user_concepts
44
+
45
def search(input, user_concepts):
    """Answer *input* against the user-supplied concepts, or the default corpus.

    When *user_concepts* is a non-empty comma-separated string, a temporary
    corpus and embedding table are built on the fly; otherwise the module-level
    chunk_data / chunk_emb prepared at startup are used.
    """
    if user_concepts:
        # Fresh concepts: fetch, chunk and embed a temporary corpus.
        data = generate_chunk_data(user_concepts.split(","))
        emb = generate_chunk_emb(student_model, data)
    else:
        # Fall back to the corpus prepared at module load.
        data, emb = chunk_data, chunk_emb
    return " ".join(search_document(input, data, emb, student_model))
55
+
56
# Gradio UI: a result box on top, a concepts box (narrow) and a question box
# (wide) with a search button below it.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Sentence Embedding and Vector Database</h1>""")

    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)

    with gr.Row():
        with gr.Column(scale=1):
            new_concept_box = gr.Textbox(show_label=False, placeholder="Add new concepts", lines=8)
            #addConceptBtn = gr.Button("Add concepts")
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")

    # BUG FIX: the original registered searchBtn.click twice — once passing only
    # [user_input] even though search() takes two arguments, which raises at
    # click time. Keep the single, correctly-wired registration.
    searchBtn.click(search, inputs=[user_input, new_concept_box], outputs=[search_result], show_progress=True)
    #addConceptBtn.click(addNewConcepts, [user_concepts], [new_concept_box])

demo.queue().launch(share=False, inbrowser=True)
myTextEmbedding.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import tensor
4
+ from transformers import BertModel, BertTokenizer
5
+ #import gzip
6
+ import pandas as pd
7
+ import requests
8
+
9
+
10
class EmbeddingModel(nn.Module):
    """Mean-pooled BERT sentence embedder.

    Wraps a Hugging Face BERT model and tokenizer; forward() returns one
    embedding vector per input sentence using attention-masked mean pooling
    over the last hidden state.
    """

    def __init__(self, bertName = "bert-base-uncased"): # other bert models can also be supported
        super().__init__()
        self.bertName = bertName
        # use BERT model
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device = "cuda"):
        """Embed the sentences in *s*; returns a [B, emb] tensor on *device*."""
        # get tokens, which also include attention_mask
        tokens = self.tokenizer(s, return_tensors='pt', padding = "max_length", truncation = True, max_length = 256).to(device)

        # get token embeddings
        output = self.model(**tokens)
        tokens_embeddings = output.last_hidden_state  # [B, T, emb]

        # mean pooling over valid (non-padding) tokens to get one text embedding
        embeddings = tokens_embeddings * tokens.attention_mask[..., None]  # [B, T, emb]
        embeddings = embeddings.sum(1)                                     # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1)                        # [B]
        embeddings = embeddings / valid_tokens[..., None]                  # [B, emb]

        return embeddings

    # from scratch: nn.CosineSimilarity(dim = 1)(q,a)
    def cos_score(self, q, a):
        """Row-wise cosine similarity between [B, emb] tensors *q* and *a*."""
        q_norm = q / q.pow(2).sum(dim=1, keepdim=True).pow(0.5)
        a_norm = a / a.pow(2).sum(dim=1, keepdim=True).pow(0.5)
        # PERF FIX: row-wise dot product gives identical values to
        # (q_norm @ a_norm.T).diagonal() without materialising the full
        # [B, B] matrix the original built only to read its diagonal.
        return (q_norm * a_norm).sum(dim=1)
42
+
43
+ # contrastive training
44
# contrastive training
class TrainModel(nn.Module):
    """Fine-tunes EmbeddingModel so sentence-pair cosine scores match targets."""

    def __init__(self):
        super().__init__()
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):
        # Predicted similarity for each pair, regressed onto the target score.
        predicted = self.m.cos_score(self.m(s1), self.m(s2))
        loss = nn.MSELoss()(predicted, score)
        return loss, predicted
53
+
54
def searchWiki(s):
    """Fetch the plain-text intro extract of the Wikipedia page titled *s*."""
    params = {
        'action': 'query',
        'format': 'json',
        'titles': s,
        'prop': 'extracts',
        'exintro': True,
        'explaintext': True,
    }
    response = requests.get('https://en.wikipedia.org/w/api.php', params=params).json()
    # The API keys pages by page-id; take the first (only) entry.
    page = next(iter(response['query']['pages'].values()))
    return page['extract'].replace("\n", "")
68
+
69
+ # sentence chunking
70
# sentence chunking
def chunk(w):
    """Naive sentence chunking: split *w* at every '.' character."""
    sentences = w.split(".")
    return sentences
72
+
73
def generate_chunk_data(concepts):
    """Build a flat list of sentence chunks from the Wikipedia intros of *concepts*.

    Each chunk is stripped and re-terminated with '.'; empty fragments
    (e.g. the tail produced by a trailing period) are dropped.
    """
    wiki_data = [searchWiki(c).replace("\n", "") for c in concepts]
    chunk_data = []
    for w in wiki_data:
        chunk_data.extend(chunk(w))

    # PERF FIX: strip, re-punctuate and drop empty fragments in one O(n) pass.
    # The original appended "." to every chunk and then removed bare "."
    # entries with a quadratic `while '.' in list: list.remove('.')` loop.
    return [c.strip() + "." for c in chunk_data if c.strip()]
84
+
85
def generate_chunk_emb(m, chunk_data):
    """Embed every chunk with model *m* on CPU; no gradients are tracked."""
    with torch.no_grad():
        chunk_emb = m(chunk_data, device="cpu")
    return chunk_emb
89
+
90
def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return the top-*topk* chunks of *chunk_data* most cosine-similar to *s*.

    chunk_emb holds one embedding row per chunk; m exposes
    __call__(sentences, device=...) and cos_score. Debug prints removed.
    """
    question = [s]
    with torch.no_grad():
        # Broadcast the single question embedding over all chunk rows.
        scores = m.cos_score(m(question, device="cpu").expand(chunk_emb.shape), chunk_emb)
    # BUG FIX: clamp k so torch.topk never asks for more results than exist.
    k = min(topk, scores.numel())
    _, idxs = torch.topk(scores, k)
    # Bounds guard, for consistency with the app.py variant of this helper.
    return [chunk_data[idx] for idx in idxs.flatten().tolist() if idx < len(chunk_data)]
98
+
myTextEmbeddingStudent.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84960ed3f791210853072638f665f07dc70e344688bf77a24c10e7d556a175bf
3
+ size 268739587
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ pandas
4
+ gradio
5
+ requests