# wb-droid's picture
# first commit
# 93ecd47
# raw
# history blame
# 2.89 kB
from myTextEmbedding import *
import gradio as gr
def generate_chunk_emb(m, chunk_data):
    """Embed ``chunk_data`` with model ``m`` on the CPU, without tracking gradients.

    Args:
        m: Callable embedding model, invoked as ``m(chunk_data, device="cpu")``.
        chunk_data: The chunks to embed; passed to ``m`` unchanged.

    Returns:
        Whatever ``m`` returns for ``chunk_data``.
    """
    # Inference only: gradient tracking is disabled for the forward pass.
    with torch.no_grad():
        return m(chunk_data, device="cpu")
def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return the chunks whose embeddings score highest against the query ``s``.

    Args:
        s: Query text; it is wrapped in a one-element list before being passed to ``m``.
        chunk_data: Sized, indexable collection of chunks; results are taken from it.
        chunk_emb: Embeddings the query is compared against. Assumed to hold one
            row per entry of ``chunk_data`` -- TODO confirm against the caller.
        m: Embedding model, called as ``m(texts, device="cpu")``; it must also
            provide ``cos_score(a, b)``.
        topk: Maximum number of chunks to return.

    Returns:
        A list with at most ``topk`` entries of ``chunk_data``, ordered like the
        flattened indices returned by ``torch.topk``. Empty when ``chunk_data``
        is empty.
    """
    # Nothing to rank; also avoids asking torch.topk for entries that do not exist.
    if len(chunk_data) == 0:
        return []

    with torch.no_grad():
        query_emb = m([s], device="cpu")
        # The single query embedding is expanded to the shape of the corpus
        # embeddings so both arguments of cos_score have matching shapes.
        result_score = m.cos_score(query_emb.expand(chunk_emb.shape), chunk_emb)
    print(result_score)

    # torch.topk raises if k exceeds the size of the dimension it selects from
    # (the last one), so clamp k to the number of scores actually available.
    available = result_score.shape[-1] if result_score.dim() else 1
    _, idxs = torch.topk(result_score, min(topk, available))

    top_indices = idxs.flatten().tolist()
    flat_scores = result_score.flatten()
    # Diagnostic console output, kept as in the original demo.
    print([flat_scores[idx] for idx in top_indices])
    print(top_indices)
    print(chunk_data)
    print(len(chunk_data))
    return [chunk_data[idx] for idx in top_indices if idx < len(chunk_data)]
# Training wrapper around the student model.
# NOTE(review): the checkpoint loaded further down is read through a
# ``.student_model`` attribute, so this class presumably has to stay defined
# here for that object to be unpickled -- confirm before removing it.
class TrainStudent(nn.Module):
    """Wrap a student model and score it against a teacher on the same input."""

    def __init__(self, student_model):
        """Store ``student_model`` as a submodule attribute."""
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        """Return the mean squared difference between student and teacher outputs.

        Args:
            s1: Input passed unchanged to both models.
            teacher_model: Callable producing the reference embedding for ``s1``.

        Returns:
            The mean over all elements of ``(student(s1) - teacher(s1)) ** 2``.
        """
        residual = self.student_model(s1) - teacher_model(s1)
        return residual.square().mean()
# Default corpus: the chunks generate_chunk_data produces for three built-in concepts.
chunk_data = generate_chunk_data(["AI", "moon", "brain"])

# The checkpoint stores a whole object; only its ``student_model`` attribute is
# kept and used as the embedding model.
# NOTE(review): torch.load unpickles arbitrary Python objects, so this file must
# come from a trusted source. Recent PyTorch releases default to
# ``weights_only=True``, which may reject a fully pickled object -- confirm
# against the installed version.
student_model = torch.load(
    "myTextEmbeddingStudent.pt",
    map_location="cpu",
).student_model

# Embed the default corpus once at import time; this acts as the vector database
# used when the user supplies no concepts of their own.
chunk_emb = generate_chunk_emb(student_model, chunk_data)
#new_chunk_data = []
#new_chunk_emb = tensor([])
def addNewConcepts(user_concepts):
    """Return ``user_concepts`` unchanged.

    Pass-through callback; no active UI control calls it in the visible code.
    """
    return user_concepts
def search(input, user_concepts=None):
    """Answer a question from user-supplied concepts or from the default corpus.

    Args:
        input: The question text. The name shadows the builtin but is kept so
            existing callers are unaffected.
        user_concepts: Optional comma-separated concept names. Surrounding
            whitespace is stripped and empty entries are ignored. When no
            usable name remains (``None``, ``""``, ``" , "``), the module-level
            default corpus ``chunk_data`` / ``chunk_emb`` is searched instead.

    Returns:
        The matching chunks joined into one string with single spaces.
    """
    # Normalise the free-text box: "AI, moon," -> ["AI", "moon"].
    concepts = [name.strip() for name in (user_concepts or "").split(",")]
    concepts = [name for name in concepts if name]

    if concepts:
        # A fresh corpus and its embeddings are built for every request that
        # names concepts; nothing is cached between calls.
        corpus = generate_chunk_data(concepts)
        corpus_emb = generate_chunk_emb(student_model, corpus)
    else:
        corpus, corpus_emb = chunk_data, chunk_emb

    result = search_document(input, corpus, corpus_emb, student_model)
    return " ".join(result)
# Demo UI: a result box on top, then a row holding the optional concept list
# (narrow column) and the question box with its Search button (wide column).
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Sentence Embedding and Vector Database</h1>""")
    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)
    with gr.Row():
        with gr.Column(scale=1):
            new_concept_box = gr.Textbox(show_label=False, placeholder="Add new concepts", lines=8)
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")

    # Register exactly one click handler and pass both textboxes to it. A second
    # handler wired with only the question box would not forward the user's
    # concepts and would write to the same output box.
    searchBtn.click(
        search,
        inputs=[user_input, new_concept_box],
        outputs=[search_result],
        show_progress=True,
    )

# Enable request queuing, then start the local server and open it in a browser.
demo.queue().launch(share=False, inbrowser=True)