File size: 5,459 Bytes
6ae201e
93ecd47
 
6ae201e
 
 
 
 
 
 
fe60887
6ae201e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93ecd47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d5654f
a6bffe9
fe60887
 
93ecd47
 
 
 
 
 
 
fe60887
93ecd47
 
 
 
 
fe60887
93ecd47
 
 
 
fe60887
93ecd47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#from myTextEmbedding import *
import inspect
import pickle

import gradio as gr
import requests
import torch
import torch.nn as nn
from torch import tensor
from transformers import BertModel, BertTokenizer
#import gzip
#import pandas as pd

class EmbeddingModel(nn.Module):
    """Sentence encoder: a BERT backbone followed by attention-masked mean pooling.

    Instances may be stored inside whole-module checkpoints that ``torch.load``
    re-resolves by class name, so keep the class name and attribute names stable.
    """

    def __init__(self, bertName = "bert-base-uncased"): # other bert models can also be supported
        super().__init__()
        self.bertName = bertName
        # use BERT model
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device = "cuda"):
        """Embed a sentence or a batch of sentences.

        Args:
            s: text passed straight to the tokenizer (a string or list of strings).
            device: device the token tensors are moved to; it must match the
                device ``self.model`` lives on.

        Returns:
            Tensor ``[B, emb]``: mean of the last-layer token embeddings over
            the non-padding tokens of each input (inputs are padded/truncated
            to 256 tokens).
        """
        # get tokens, which also include attention_mask
        tokens = self.tokenizer(s, return_tensors='pt', padding = "max_length", truncation = True, max_length = 256).to(device)

        # get token embeddings
        output = self.model(**tokens)
        tokens_embeddings = output.last_hidden_state  # [B, T, emb]

        # mean pooling: zero out padding positions, then divide by the real-token count
        embeddings = tokens_embeddings * tokens.attention_mask[...,None] # [B, T, emb]
        embeddings = embeddings.sum(1) # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1) # [B]
        return embeddings / valid_tokens[...,None] # [B, emb]

    # from scratch: nn.CosineSimilarity(dim = 1)(q,a)
    def cos_score(self, q, a, eps=1e-8):
        """Row-wise cosine similarity between ``q`` and ``a``.

        Row i of ``q`` is compared with row i of ``a`` (like
        ``nn.CosineSimilarity(dim=1)``, up to eps handling). Rows broadcast, so
        a single query row ``[1, emb]`` can be scored against every row of ``a``.

        Args:
            q: 2-D tensor ``[B, emb]`` (or ``[1, emb]``).
            a: 2-D tensor ``[B, emb]`` (or ``[1, emb]``).
            eps: lower bound applied to each row norm, so all-zero rows give a
                score of 0 instead of NaN.

        Returns:
            1-D tensor with one similarity score per (broadcast) row.
        """
        q_norm = q / q.norm(dim=1, keepdim=True).clamp_min(eps)
        a_norm = a / a.norm(dim=1, keepdim=True).clamp_min(eps)
        # The element-wise product summed per row equals diag(q_norm @ a_norm.T)
        # without materialising the full [B, B] similarity matrix.
        return (q_norm * a_norm).sum(dim=1)
    
# contrastive training
class TrainModel(nn.Module):
    """Contrastive fine-tuning wrapper around ``EmbeddingModel``.

    Both sentences of each pair go through the same encoder. Their cosine
    similarity is regressed onto a target score with mean-squared error.
    """

    def __init__(self):
        super().__init__()
        # Shared encoder being fine-tuned.
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):
        """Return ``(mse_loss, cos_score)`` for sentence batches ``s1``/``s2`` vs target ``score``."""
        encoder = self.m
        left = encoder(s1)
        right = encoder(s2)
        predicted = encoder.cos_score(left, right)
        criterion = nn.MSELoss()
        return criterion(predicted, score), predicted
    
def searchWiki(s, timeout=10):
    """Return the plain-text intro of the English Wikipedia page titled ``s``.

    Newlines are removed from the returned text. When the API response carries
    no extract (for example, a page that does not exist), ``""`` is returned
    instead of raising ``KeyError``.

    Args:
        s: page title to look up.
        timeout: seconds to wait for the HTTP response before ``requests``
            gives up (previously the request could block forever).

    Raises:
        requests.HTTPError: the API answered with a non-2xx status.
        requests.RequestException: connection failure or timeout.
    """
    response = requests.get(
        'https://en.wikipedia.org/w/api.php',
        params={
            'action': 'query',
            'format': 'json',
            'titles': s,
            'prop': 'extracts',
            'exintro': True,
            'explaintext': True,
        },
        # Wikimedia asks API clients to identify themselves; the generic
        # python-requests User-Agent may be throttled or rejected.
        headers={'User-Agent': 'SentenceEmbeddingVectorDB-demo/1.0 (python-requests)'},
        timeout=timeout,
    )
    response.raise_for_status()
    pages = response.json().get('query', {}).get('pages', {})
    page = next(iter(pages.values()), {})
    return page.get('extract', '').replace("\n","")

# sentence chunking
def chunk(w):
    """Split text ``w`` into sentence pieces at every "." (the dots are dropped).

    Pieces keep their surrounding whitespace, and text ending in "." yields a
    final empty piece; callers are expected to strip/filter.
    """
    pieces = w.split(".")
    return pieces

def generate_chunk_data(concepts):
    """Fetch the Wikipedia intro for each concept and split it into sentence chunks.

    Args:
        concepts: iterable of Wikipedia page titles.

    Returns:
        List of whitespace-stripped sentences, each re-terminated with ".", in
        concept order and then sentence order. Pieces that are empty after
        stripping are dropped.
    """
    chunk_data = []
    for concept in concepts:
        wiki_text = searchWiki(concept).replace("\n","")
        # extend() in place: the previous `list = list + chunk(...)` re-copied
        # the whole accumulated list on every concept.
        chunk_data.extend(chunk(wiki_text))

    # A piece that strips to "" would become a bare "." chunk; filter those out
    # in one pass instead of repeatedly scanning with list.remove('.').
    return [piece.strip() + "." for piece in chunk_data if piece.strip()]

@torch.no_grad()
def generate_chunk_emb(m, chunk_data):
    """Embed every chunk with encoder ``m`` on the CPU, without tracking gradients.

    Args:
        m: encoder called as ``m(chunk_data, device="cpu")``.
        chunk_data: list of text chunks.

    Returns:
        Whatever ``m`` returns for the batch (one embedding row per chunk).
    """
    return m(chunk_data, device = "cpu")

def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return up to ``topk`` chunks from ``chunk_data`` most similar to query ``s``.

    Args:
        s: query string.
        chunk_data: list of text chunks; entry i is assumed to correspond to
            row i of ``chunk_emb``.
        chunk_emb: 2-D tensor of chunk embeddings, one row per chunk.
        m: encoder, called as ``m([s], device="cpu")`` and providing
            ``cos_score(q, a)``.
        topk: maximum number of chunks to return.

    Returns:
        List of chunk strings ordered from most to least similar. When the
        database holds fewer than ``topk`` chunks, all of them are returned
        (``torch.topk`` previously raised). An empty database returns ``[]``.
    """
    import logging

    logger = logging.getLogger(__name__)

    if len(chunk_data) == 0 or chunk_emb.shape[0] == 0 or topk <= 0:
        return []

    with torch.no_grad():
        query_emb = m([s], device = "cpu")
        # Repeat the single query row so cos_score pairs it with every chunk row.
        result_score = m.cos_score(query_emb.expand(chunk_emb.shape), chunk_emb).flatten()

    # torch.topk raises if k exceeds the number of scores, so clamp it.
    k = min(topk, result_score.numel())
    top_scores, idxs = torch.topk(result_score, k)
    idx_list = idxs.tolist()
    # Debug output previously printed the entire chunk database on every query.
    logger.debug("top-%d scores %s at indices %s", k, top_scores.tolist(), idx_list)

    # Guard against chunk_emb having more rows than chunk_data.
    return [chunk_data[idx] for idx in idx_list if idx < len(chunk_data)]

# create the student training model
class TrainStudent(nn.Module):
    """Distillation wrapper that trains ``student_model`` to mimic a teacher's embeddings.

    Whole-module checkpoints re-resolve this class by name when unpickled
    (presumably including myTextEmbeddingStudent.pt — confirm), so keep the
    class name and the ``student_model`` attribute stable.
    """

    def __init__(self, student_model):
        super().__init__()
        # The encoder being trained; read back later as `.student_model`.
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        """Return the mean squared difference between student and teacher embeddings of ``s1``."""
        student_out = self.student_model(s1)
        teacher_out = teacher_model(s1)
        diff = student_out - teacher_out
        return diff.pow(2).mean()
    
# Load the distilled student encoder and the prebuilt vector database.
# Both files are full Python pickles (a whole pickled module and a dict-like
# object); unpickling can execute arbitrary code, so only ship trusted files.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects whole
# pickled modules, so opt out explicitly when this torch version supports it.
_load_kwargs = {"map_location": 'cpu'}
if "weights_only" in inspect.signature(torch.load).parameters:
    _load_kwargs["weights_only"] = False
student_model=torch.load("myTextEmbeddingStudent.pt", **_load_kwargs).student_model.eval()

with open("vector_database.pkl","rb") as f:
    # Used below via the "chunk_data", "chunk_emb" and "concepts" keys.
    vector_database=pickle.load(f)

def addNewConcepts(user_concepts):
    """Placeholder for adding concepts to the database; returns its argument unchanged."""
    result = user_concepts
    return result

def search(input, user_concepts=None):
    """Gradio callback: answer a question with the best-matching database chunks.

    Args:
        input: the user's question (the name shadows the builtin but is kept
            for compatibility with existing callers).
        user_concepts: contents of the concept textbox. Currently unused; it is
            optional so the handler also works when wired with only the
            question input (previously that raised TypeError).

    Returns:
        The top-matching chunks joined with single spaces.
    """
    result = search_document(
        input,
        vector_database["chunk_data"],
        vector_database["chunk_emb"],
        student_model,
    )
    return " ".join(result)

# Gradio UI: a result box on top; below it, the concept list (left) beside the
# question box and search button (right).
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Sentence Embedding and Vector Database</h1>""")

    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)

    with gr.Row():
        with gr.Column(scale=1):
            # The placeholder lists the concepts currently stored in the database.
            new_concept_box = gr.Textbox(show_label=False, placeholder="Currently supported concepts in vector database:" + str(vector_database["concepts"]), lines=8)
            # TODO: an "Add concepts" button wired to addNewConcepts is not implemented yet.
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")

    # Register the search handler exactly once. A second registration with only
    # [user_input] made every click run the search twice, and that extra run
    # called search() with a missing argument.
    searchBtn.click(search, inputs=[user_input, new_concept_box], outputs=[search_result], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)