Spaces:
Sleeping
Sleeping
| from myTextEmbedding import * | |
| import gradio as gr | |
def generate_chunk_emb(m, chunk_data):
    """Embed ``chunk_data`` with model ``m`` on the CPU, without autograd.

    Args:
        m: Callable embedding model; invoked as ``m(chunk_data, device="cpu")``.
        chunk_data: The chunks to embed; handed to ``m`` unchanged.

    Returns:
        Whatever ``m`` returns for ``chunk_data``.
    """
    # Inference only, so gradient tracking is switched off for the call.
    with torch.no_grad():
        return m(chunk_data, device="cpu")
def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return the chunks that score highest against the question ``s``.

    Args:
        s: The question text.
        chunk_data: Sequence of chunks; indexed with the positions selected
            from the scores.
        chunk_emb: Tensor of chunk embeddings; the question embedding is
            expanded to this tensor's shape before scoring.
        m: Embedding model. It is called as ``m([s], device="cpu")`` and must
            provide ``cos_score(a, b)``.
        topk: Maximum number of chunks to return.

    Returns:
        A list with at most ``topk`` entries of ``chunk_data``, in the order
        ``torch.topk`` reports them (largest score first).
    """
    with torch.no_grad():
        question_emb = m([s], device="cpu")
        # Repeat the single question embedding so it lines up with every chunk.
        scores = m.cos_score(question_emb.expand(chunk_emb.shape), chunk_emb)
        # torch.topk raises if k exceeds the size of the ranked (last)
        # dimension, which happens whenever the corpus has fewer than `topk`
        # chunks. Clamp k so a small corpus returns what it has.
        available = scores.shape[-1] if scores.dim() > 0 else 1
        _, idxs = torch.topk(scores, min(topk, available))
    # Keep the defensive bound check: never index past the end of chunk_data.
    return [
        chunk_data[idx]
        for idx in idxs.flatten().tolist()
        if idx < len(chunk_data)
    ]
# Training wrapper for the student embedding model.
class TrainStudent(nn.Module):
    """Wrap a student model and score it against a teacher model.

    NOTE(review): the module-level ``torch.load(...)`` in this file reads
    ``.student_model`` from the loaded object, so the checkpoint is presumably
    a pickled instance of this class and the class must stay importable here
    -- confirm before moving or renaming it.
    """

    def __init__(self, student_model):
        """Store ``student_model`` as the wrapped model."""
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        """Return the mean squared difference of student and teacher outputs.

        Args:
            s1: Input passed unchanged to both models.
            teacher_model: Callable producing the reference output for ``s1``.

        Returns:
            The mean of the element-wise squared difference between
            ``self.student_model(s1)`` and ``teacher_model(s1)``.
        """
        student_out = self.student_model(s1)
        teacher_out = teacher_model(s1)
        gap = student_out - teacher_out
        return gap.pow(2).mean()
# Default corpus, built from three seed concepts.
chunk_data = generate_chunk_data(["AI", "moon", "brain"])

# The checkpoint stores a wrapper object; only its `student_model` attribute is
# used for inference, loaded onto the CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
student_model = torch.load(
    "myTextEmbeddingStudent.pt",
    map_location="cpu",
).student_model

# create the embedding vector database
chunk_emb = generate_chunk_emb(student_model, chunk_data)
def addNewConcepts(user_concepts):
    """Return ``user_concepts`` unchanged (pass-through callback)."""
    return user_concepts
def search(input, user_concepts):
    """Answer a question from the user's concepts or from the default corpus.

    Args:
        input: The question text. (The name shadows the builtin but is kept
            so existing callers are unaffected.)
        user_concepts: Comma-separated concept names typed by the user; may
            be empty or ``None``.

    Returns:
        The best-matching chunks joined with single spaces.
    """
    # Strip whitespace around each concept and drop empty entries, so input
    # like "AI, moon," or a whitespace-only box does not produce blank or
    # padded concept names.
    concepts = [part.strip() for part in (user_concepts or "").split(",")]
    concepts = [name for name in concepts if name]

    if concepts:
        corpus = generate_chunk_data(concepts)
        corpus_emb = generate_chunk_emb(student_model, corpus)
    else:
        # No usable concept supplied: fall back to the prebuilt corpus.
        corpus, corpus_emb = chunk_data, chunk_emb

    result = search_document(input, corpus, corpus_emb, student_model)
    return " ".join(result)
# Gradio UI: a result box on top, then a row holding the concept box and the
# question box with its Search button.
# NOTE(review): the nesting below is reconstructed; confirm the button's
# placement against the intended layout.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Sentence Embedding and Vector Database</h1>""")
    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)
    with gr.Row():
        with gr.Column(scale=1):
            new_concept_box = gr.Textbox(show_label=False, placeholder="Add new concepts", lines=8)
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")

    # Register exactly one click handler. `search` takes two arguments, so
    # both textboxes must be wired as inputs; a second registration passing
    # only `user_input` would fail on every click.
    searchBtn.click(
        search,
        inputs=[user_input, new_concept_box],
        outputs=[search_result],
        show_progress=True,
    )

demo.queue().launch(share=False, inbrowser=True)