Spaces:
Running
Running
| import json | |
| import argparse | |
| from pathlib import Path | |
| from typing import List | |
| import gradio as gr | |
| import faiss | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from sentence_transformers import SentenceTransformer | |
# Usage instructions shown via gr.Markdown in main().
# The embedded sample must itself be valid JSON: json.load() (used by
# upload_file_fn) rejects a trailing comma after the last array element.
file_example = """Please upload a JSON file with a "text" field (with optional "title" field). For example
```JSON
[
    {"title": "", "text": "This is an example text without the title"},
    {"title": "Title A", "text": "This is an example text with the title"},
    {"title": "Title B", "text": "This is an example text with the title"}
]
```
Due to the computation resources, please test with small scale data (<1000).
"""
def create_index(embeddings, use_gpu):
    """Build a FAISS inner-product index holding ``embeddings``.

    Args:
        embeddings: Non-empty sequence of vectors; the length of the first
            vector is used as the index dimensionality.
        use_gpu: When true, the flat index is passed through
            ``faiss.index_cpu_to_all_gpus`` with ``shard`` and ``useFloat16``
            enabled before the vectors are added.

    Returns:
        The index with every vector added.
    """
    dim = len(embeddings[0])
    flat_index = faiss.IndexFlatIP(dim)
    vectors = np.asarray(embeddings, dtype=np.float32)

    if not use_gpu:
        flat_index.add(vectors)
        return flat_index

    cloner_options = faiss.GpuMultipleClonerOptions()
    cloner_options.shard = True
    cloner_options.useFloat16 = True
    gpu_index = faiss.index_cpu_to_all_gpus(flat_index, co=cloner_options)
    gpu_index.add(vectors)
    return gpu_index
def upload_file_fn(
    file_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True)
):
    """Read an uploaded JSON file, embed its documents and build a search index.

    Args:
        file_path: Path of the uploaded JSON file. The file must contain a list
            of objects with a "text" field and an optional "title" field.
        progress: Gradio progress tracker; not referenced in the body
            (presumably it lets Gradio mirror the tqdm loop via
            ``track_tqdm=True`` -- confirm against the Gradio docs).

    Returns:
        A ``(document_state, update)`` pair. On success ``document_state`` is a
        dict with keys "document_data" (the possibly truncated records) and
        "document_index" (the index over their embeddings) and ``update`` is
        ``gr.update(interactive=True)``. On any failure the pair is
        ``(None, gr.update(interactive=False))``.
    """
    max_documents = 1000
    # retrieve_document_fn always returns five results, so fewer cannot be served.
    min_documents = 5
    batch_size = 16

    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        gr.Info(f"Upload {len(document_data)} documents.")
        if len(document_data) > max_documents:
            gr.Info(f"Cut uploaded documents to {max_documents} due to the computation resource.")
            document_data = document_data[:max_documents]
        documents = []
        for obj in document_data:
            text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            if str(text).strip():
                documents.append(text)
            else:
                # Keep one entry per record so index positions line up with
                # document_data; blank texts are replaced by the EOS token.
                documents.append(model.tokenizer.eos_token)
    except Exception as e:
        print(e)
        # gr.Error is an exception class and only shows up when raised; raising
        # would skip the return values below, so report through gr.Warning.
        gr.Warning("Read the file failed. Please check the data format.")
        gr.Warning(str(e))
        return None, gr.update(interactive=False)

    if len(documents) < min_documents:
        gr.Warning(f"Please upload at least {min_documents} documents.")
        return None, gr.update(interactive=False)

    documents_embeddings = []
    for i in tqdm(range(0, len(documents), batch_size)):
        batch_documents = documents[i: i + batch_size]
        batch_embeddings = model.encode(batch_documents, show_progress_bar=True)
        documents_embeddings.extend(batch_embeddings)

    document_index = create_index(documents_embeddings, use_gpu=False)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    document_state = {"document_data": document_data, "document_index": document_index}
    return document_state, gr.update(interactive=True)
def clear_file_fn():
    """Handle clearing of the upload box.

    Returns:
        A ``(None, update)`` pair: ``None`` for the document state and
        ``gr.update(interactive=True)`` for the second output component.
    """
    cleared_state = None
    component_update = gr.update(interactive=True)
    return cleared_state, component_update
def retrieve_document_fn(question, document_states, instruct):
    """Return the texts of the five best-matching documents for a query.

    Args:
        question: Query text entered by the user.
        document_states: ``None`` when nothing is indexed, otherwise a dict
            with keys "document_data" and "document_index".
        instruct: Instruction text that is prepended to the query before it is
            embedded.

    Returns:
        Six values: five document texts (``None`` where no result is
        available) followed by the document state.
    """
    num_retrieval_doc = 5
    if document_states is None:
        gr.Warning("Please upload documents first!")
        return [None for _ in range(num_retrieval_doc)] + [None]

    # Check the query on its own: the instruction is normally non-empty, so
    # testing the concatenation would never detect an empty query.
    if question is None or not str(question).strip():
        gr.Warning("Please enter a non-empty query.")
        return None, None, None, None, None, document_states

    document_data, document_index = document_states["document_data"], document_states["document_index"]
    instruction = "" if instruct is None else str(instruct)
    question_with_inst = instruction + str(question)

    question_embedding = model.encode([question_with_inst])
    _, batch_inxs = document_index.search(question_embedding, k=min(len(document_data), 150))

    # FAISS reports missing hits as -1; skip them instead of wrapping around
    # to the end of document_data.
    answers = [
        document_data[int(i)]["text"]
        for i in batch_inxs[0][:num_retrieval_doc]
        if i >= 0
    ]
    # Pad so that exactly num_retrieval_doc outputs are always produced.
    answers.extend([None] * (num_retrieval_doc - len(answers)))
    return answers[0], answers[1], answers[2], answers[3], answers[4], document_states
def main(args):
    """Load the embedding model, assemble the Gradio UI and launch it.

    Args:
        args: Parsed command-line namespace; ``model_name_or_path`` and
            ``revision`` are read from it.
    """
    # upload_file_fn and retrieve_document_fn read `model` as a module global.
    global model
    model = SentenceTransformer(args.model_name_or_path, revision=args.revision)

    document_state = gr.State()

    base_dir = Path(__file__).parent
    with open(base_dir / "resources/head.html") as html_file:
        head = html_file.read().strip()

    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )
    displayed_model = "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5"

    with gr.Blocks(
        theme=theme,
        head=head,
        css=base_dir / "resources/styles.css",
        title="KaLM-Embedding",
        fill_height=True,
        analytics_enabled=False,
    ) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        # Single fixed choice, non-interactive: shown for information only.
        gr.Radio([displayed_model], value=displayed_model, label="Model Selection", interactive=False)

        # Components are created in the same order as they are handed to the
        # Interface: query, result boxes, output state, instruction.
        query_box = gr.Textbox(label="Query")
        recall_boxes = [
            gr.Text(label=label)
            for label in ("Recall-1", "Recall-2", "Recall-3", "Recall-4", "Recall-5")
        ]
        output_state = gr.State()
        instruct_box = gr.Textbox(
            "Instruct: Given a query, retrieve documents that answer the query. \n Query: ",
            label="Instruct of Query",
            lines=2,
        )
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=[query_box, document_state],
            outputs=recall_boxes + [output_state],
            additional_inputs=[instruct_box],
            concurrency_limit=1,
            allow_flagging="never",
        )

        # Upload and clear both write the document state and the first input
        # component of the retrieval interface.
        first_input = retrieval_interface.input_components[0]
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_state, first_input],
            queue=True,
            trigger_mode="once",
        )
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_state, first_input],
            queue=True,
            trigger_mode="once",
        )

    demo.launch()
if __name__ == "__main__":
    # Command-line entry point: both options are forwarded to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    main(args)