"""Gradio demo: semantic search over text chunks with a distilled student embedding model.

At start-up, chunks for a default set of concepts are generated and embedded once.
On each search the user may optionally enter comma-separated concepts of their own;
in that case an index is built for those concepts and searched instead of the default.
"""
import functools
import logging

import gradio as gr
import torch
from torch import nn

# Project module; assumed to provide ``generate_chunk_data``. Kept as a star
# import because the pickled checkpoint loaded below may reference classes it
# defines by name (NOTE(review): replace with explicit imports once confirmed).
from myTextEmbedding import *  # noqa: F401,F403

logger = logging.getLogger(__name__)

DEFAULT_CONCEPTS = ["AI", "moon", "brain"]
CHECKPOINT_PATH = "myTextEmbeddingStudent.pt"


def generate_chunk_emb(m, chunk_data):
    """Embed *chunk_data* with model *m* on the CPU, without tracking gradients.

    Args:
        m: embedding model, called as ``m(list_of_texts, device="cpu")``.
        chunk_data: sequence of text chunks to embed.

    Returns:
        Whatever *m* returns for the batch -- presumably a (num_chunks, dim)
        tensor; TODO confirm against myTextEmbedding.
    """
    with torch.no_grad():
        return m(chunk_data, device="cpu")


def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return up to *topk* chunks from *chunk_data* with the highest ``cos_score`` to *s*.

    Args:
        s: query string.
        chunk_data: text chunks; entry ``i`` must correspond to row ``i`` of *chunk_emb*.
        chunk_emb: embeddings of *chunk_data*, as produced by ``generate_chunk_emb``.
        m: the model that produced *chunk_emb*; must expose ``cos_score``.
        topk: maximum number of chunks to return.

    Returns:
        List of chunk strings, highest score first (``torch.topk`` sorts by default).
        Empty when *chunk_data* is empty.
    """
    if not chunk_data:
        return []
    with torch.no_grad():
        query_emb = m([s], device="cpu")
        # Broadcast the single query embedding to one copy per chunk.
        scores = m.cos_score(query_emb.expand(chunk_emb.shape), chunk_emb)
    # torch.topk raises when k exceeds the size of the last dimension, which
    # happens whenever fewer than `topk` chunks are available.
    k = min(topk, scores.shape[-1])
    _, idxs = torch.topk(scores, k)
    top = idxs.flatten().tolist()
    logger.debug("search scores=%s top_idxs=%s", scores, top)
    return [chunk_data[i] for i in top if i < len(chunk_data)]


class TrainStudent(nn.Module):
    """Distillation wrapper: MSE between student and teacher embeddings.

    Not used for training in this script, but the checkpoint appears to be a
    pickled instance of this class (it exposes ``.student_model``), so the class
    must be defined here for ``torch.load`` to unpickle it.
    """

    def __init__(self, student_model):
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        emb_student = self.student_model(s1)
        emb_teacher = teacher_model(s1)
        mse = (emb_student - emb_teacher).pow(2).mean()
        return mse


def _load_student_model(path):
    """Load the pickled training wrapper at *path* and return its ``student_model``."""
    # SECURITY: this unpickles arbitrary objects -- only load trusted checkpoints.
    # torch>=2.6 defaults to weights_only=True, which rejects a pickled module,
    # so opt out explicitly; torch<1.13 has no such parameter (TypeError).
    try:
        wrapper = torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        wrapper = torch.load(path, map_location="cpu")
    return wrapper.student_model


chunk_data = generate_chunk_data(DEFAULT_CONCEPTS)
student_model = _load_student_model(CHECKPOINT_PATH)
# Default "vector database": embeddings of the default concepts, built once.
chunk_emb = generate_chunk_emb(student_model, chunk_data)


def addNewConcepts(user_concepts):
    """Return *user_concepts* unchanged (currently not wired to any UI control)."""
    return user_concepts


def _parse_concepts(user_concepts):
    """Split a comma-separated string into stripped, non-empty concept names."""
    if not user_concepts:
        return []
    return [c.strip() for c in user_concepts.split(",") if c.strip()]


@functools.lru_cache(maxsize=16)
def _build_index(concepts):
    """Return ``(chunk_data, chunk_emb)`` for a tuple of concepts.

    Cached so repeated searches over the same concepts do not regenerate and
    re-embed the chunks each time.
    """
    data = generate_chunk_data(list(concepts))
    return data, generate_chunk_emb(student_model, data)


def search(query, user_concepts):
    """Gradio handler: search *query* against user concepts, or the default index.

    Args:
        query: the user's question.
        user_concepts: comma-separated concept names; blank means "use defaults".

    Returns:
        The matching chunks joined with single spaces.
    """
    concepts = _parse_concepts(user_concepts)
    if concepts:
        data, emb = _build_index(tuple(concepts))
    else:
        data, emb = chunk_data, chunk_emb
    return " ".join(search_document(query, data, emb, student_model))


with gr.Blocks() as demo:
    gr.HTML("\n\nSentence Embedding and Vector Database\n\n")
    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)
    with gr.Row():
        with gr.Column(scale=1):
            new_concept_box = gr.Textbox(
                show_label=False, placeholder="Add new concepts", lines=8
            )
        with gr.Column(scale=4):
            user_input = gr.Textbox(
                show_label=False, placeholder="Enter question on the concept...", lines=8
            )
            searchBtn = gr.Button("Search", variant="primary")
    # Exactly one handler: search() needs both the question and the concept box.
    searchBtn.click(
        search,
        inputs=[user_input, new_concept_box],
        outputs=[search_result],
        show_progress=True,
    )

if __name__ == "__main__":
    demo.queue().launch(share=False, inbrowser=True)