#from myTextEmbedding import *
import gradio as gr
import torch
import torch.nn as nn
from torch import tensor
from transformers import BertModel, BertTokenizer
import requests
import pickle
class EmbeddingModel(nn.Module):
    def __init__(self, bertName="bert-base-uncased"):  # other BERT variants can also be used
        super().__init__()
        self.bertName = bertName
        # pretrained BERT encoder and its matching tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device="cuda"):
        # tokenize; the result also carries the attention_mask used for pooling
        # (the model itself must already live on `device`)
        tokens = self.tokenizer(s, return_tensors='pt', padding="max_length",
                                truncation=True, max_length=256).to(device)
        # per-token contextual embeddings
        output = self.model(**tokens)
        tokens_embeddings = output.last_hidden_state  # [B, T, emb]
        # mean pooling over valid (non-padding) tokens -> one embedding per text
        embeddings = tokens_embeddings * tokens.attention_mask[..., None]  # zero out padding
        embeddings = embeddings.sum(1)                     # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1)        # [B]
        embeddings = embeddings / valid_tokens[..., None]  # [B, emb]
        return embeddings

    # cosine similarity from scratch, equivalent to nn.CosineSimilarity(dim=1)(q, a)
    def cos_score(self, q, a):
        q_norm = q / q.pow(2).sum(dim=1, keepdim=True).pow(0.5)
        a_norm = a / a.pow(2).sum(dim=1, keepdim=True).pow(0.5)
        return (q_norm @ a_norm.T).diagonal()
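
# Usage sketch (illustrative, not part of the app flow): embed two made-up
# sentences and score their similarity; pass device="cpu" when no GPU exists.
#
# emb_model = EmbeddingModel().eval()
# with torch.no_grad():
#     q = emb_model(["what is a vector database?"], device="cpu")
#     a = emb_model(["a vector database stores embeddings."], device="cpu")
# print(emb_model.cos_score(q, a))  # one cosine score in [-1, 1] per pair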
# contrastive training: regress the cosine score of a sentence pair onto a gold similarity score
class TrainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):
        cos_score = self.m.cos_score(self.m(s1), self.m(s2))
        loss = nn.MSELoss()(cos_score, score)
        return loss, cos_score
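
# Training sketch (assumed workflow; the actual data and loop are not in this
# file): one regression step on an STS-style (s1, s2, gold score) example,
# assuming a GPU is available.
#
# train_model = TrainModel().to("cuda").train()
# opt = torch.optim.AdamW(train_model.parameters(), lr=2e-5)
# loss, pred = train_model(["a man is playing guitar"],
#                          ["someone plays an instrument"],
#                          tensor([0.8]).to("cuda"))
# opt.zero_grad(); loss.backward(); opt.step()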
def searchWiki(s):
    # fetch the plain-text introduction of the Wikipedia page titled s
    response = requests.get(
        'https://en.wikipedia.org/w/api.php',
        params={
            'action': 'query',
            'format': 'json',
            'titles': s,
            'prop': 'extracts',
            'exintro': True,
            'explaintext': True,
        },
    ).json()
    page = next(iter(response['query']['pages'].values()))
    return page['extract'].replace("\n", "")
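
# Example (illustrative): searchWiki("Autoencoder") returns that page's intro
# section as one plain-text line, e.g. "An autoencoder is a type of ...".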
# naive sentence chunking: split on periods
def chunk(w):
    return w.split(".")

def generate_chunk_data(concepts):
    # one flat chunk list over all concepts' Wikipedia intros
    wiki_data = [searchWiki(c) for c in concepts]  # searchWiki already strips newlines
    chunk_data = []
    for w in wiki_data:
        chunk_data = chunk_data + chunk(w)
    # re-append the period and drop chunks that were empty (now a bare ".")
    chunk_data = [c.strip() + "." for c in chunk_data]
    chunk_data = [c for c in chunk_data if c != "."]
    return chunk_data
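
# Example (illustrative): generate_chunk_data(["Autoencoder", "Transformer"])
# fetches both intros and returns one flat list of sentence chunks such as
# ["An autoencoder is a type of artificial neural network.", ...], each
# ending with a period.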
def generate_chunk_emb(m, chunk_data):
    # embed every chunk once up front; inference only, so no gradients
    with torch.no_grad():
        emb = m(chunk_data, device="cpu")
    return emb
def search_document(s, chunk_data, chunk_emb, m, topk=3):
    question = [s]
    with torch.no_grad():
        # cos_score compares rows pairwise (it takes the diagonal), so the single
        # question embedding is expanded to match chunk_emb's [N, emb] shape
        result_score = m.cos_score(m(question, device="cpu").expand(chunk_emb.shape), chunk_emb)
    # return the chunks behind the topk highest similarity scores
    _, idxs = torch.topk(result_score, topk)
    return [chunk_data[idx] for idx in idxs.flatten().tolist() if idx < len(chunk_data)]
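
# Retrieval sketch (illustrative; names here are assumptions, not app state):
# build a small database once, then answer ad-hoc queries against it.
#
# emb_model = EmbeddingModel().eval()
# chunks = generate_chunk_data(["Autoencoder"])
# embs = generate_chunk_emb(emb_model, chunks)
# hits = search_document("what does an autoencoder learn?", chunks, embs, emb_model)
# print(hits)  # the topk=3 most similar sentence chunks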
# student training model for distillation: the student learns to
# reproduce the teacher's sentence embeddings
class TrainStudent(nn.Module):
    def __init__(self, student_model):
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        emb_student = self.student_model(s1)
        emb_teacher = teacher_model(s1)
        mse = (emb_student - emb_teacher).pow(2).mean()
        return mse
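
# Distillation sketch (assumed workflow; the run that actually produced
# myTextEmbeddingStudent.pt is not in this file): a frozen teacher provides
# target embeddings and the student minimizes the MSE defined above.
#
# teacher = EmbeddingModel("bert-base-uncased").eval()
# for p in teacher.parameters():
#     p.requires_grad_(False)
# trainer = TrainStudent(EmbeddingModel("bert-base-uncased"))  # a smaller encoder in practice
# opt = torch.optim.AdamW(trainer.parameters(), lr=2e-5)
# for batch in text_batches:  # hypothetical iterable of List[str]
#     loss = trainer(batch, teacher)
#     opt.zero_grad(); loss.backward(); opt.step()
# torch.save(trainer, "myTextEmbeddingStudent.pt")  # matches the torch.load below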
# load the distilled student embedding model; torch.load unpickles the full
# TrainStudent module, so the class definition above must be in scope
student_model = torch.load("myTextEmbeddingStudent.pt", map_location='cpu').student_model.eval()

# load the precomputed vector database
with open("vector_database.pkl", "rb") as f:
    vector_database = pickle.load(f)
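
# Assumed layout of the loaded vector_database, inferred from its use below;
# it can be rebuilt with generate_chunk_data(...) and generate_chunk_emb(...):
#   {"concepts":   [...page titles...],
#    "chunk_data": [...sentence chunks...],
#    "chunk_emb":  torch.Tensor of shape [num_chunks, emb_dim]}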
def addNewConcepts(user_concepts):
    # placeholder: adding new concepts at runtime is not implemented yet
    return user_concepts

def search(query, user_concepts):
    # user_concepts is wired in from the UI but not used by the search yet
    result = search_document(query, vector_database["chunk_data"],
                             vector_database["chunk_emb"], student_model)
    return " ".join(result)
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Sentence Embedding and Vector Database</h1>""")
    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)
    with gr.Row():
        with gr.Column(scale=1):
            new_concept_box = gr.Textbox(show_label=False,
                                         placeholder="Currently supported concepts in vector database: "
                                         + str(vector_database["concepts"]),
                                         lines=8)
            #addConceptBtn = gr.Button("Add concepts")
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")
    #addConceptBtn.click(addNewConcepts, [user_concepts], [new_concept_box])
    # search takes (query, user_concepts), so both textboxes are passed as inputs
    searchBtn.click(search, inputs=[user_input, new_concept_box], outputs=[search_result], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)