Spaces:
Running
Running
| import os | |
| import json | |
| import gradio as gr | |
| from uuid import uuid4 | |
| from pprint import pprint | |
| from dotenv import load_dotenv | |
| from qdrant_client import QdrantClient | |
| from fastembed import TextEmbedding | |
| from langchain_core.documents import Document | |
| from src.utils.qdrant_vector_store import QdrantVectorStore, RetrievalMode | |
| from src.utils.fastembed_manager import add_custom_embedding_model | |
| from src.utils.fastembed_sparse import FastEmbedSparse | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http import models | |
| load_dotenv() | |
| COLLECTION_NAME = "test_collection" | |
| qdrant_api_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.T97XMDCPTieAz5kVDkKtF0_HU_9BkFA71tH2j4WovkU" | |
| qdrant_endpoint = "https://9ea9b30f-4284-455b-bbae-65e4e458ed35.europe-west3-0.gcp.cloud.qdrant.io" | |
| qdrant_client = QdrantClient( | |
| url=qdrant_endpoint, | |
| api_key=qdrant_api_key, | |
| prefer_grpc=True, | |
| ) | |
| sparse_embeddings = FastEmbedSparse(model_name="Qdrant/BM25") | |
| embedding = add_custom_embedding_model( | |
| model_name="models/Vietnamese_Embedding_OnnX_Quantized", | |
| source_model="Mint1456/Vietnamese_Embedding_OnnX_Quantized", | |
| dim=1024, | |
| source_file="model.onnx" | |
| ) | |
| client = QdrantVectorStore( | |
| client=qdrant_client, | |
| collection_name=COLLECTION_NAME, | |
| embedding=embedding, | |
| sparse_embedding=sparse_embeddings, | |
| retrieval_mode=RetrievalMode.HYBRID, | |
| ) | |
| def search_document(query, top_k, search_type, slider_lambda): | |
| if not query.strip(): | |
| return "⚠️ Enter query to look up!" | |
| try: | |
| if search_type == "Default": | |
| hits = client.similarity_search_with_score(query=query,k=top_k) | |
| else: | |
| hits = client.max_marginal_relevance_search_with_score(query=query, k=top_k, lambda_mult=slider_lambda) | |
| except Exception as e: | |
| print("error", e) | |
| total_found = len(hits) | |
| if total_found == 0: | |
| return json.dumps([], indent=2) | |
| # Nếu tìm được 10 mà đòi 15 -> chỉ lấy 10. Nếu tìm được 100 mà đòi 15 -> lấy 15 | |
| safe_k = min(top_k, total_found) | |
| results = [] | |
| for i in range(safe_k): | |
| hit = hits[i] | |
| if hit[0].metadata.get('parent_chunking', None) is not None: | |
| content = hit[0].metadata['parent_chunking'] | |
| elif hit[0].metadata.get('type', None) == "intro": | |
| content = hit[0].page_content | |
| else: | |
| content = None | |
| results.append({ | |
| "Score": round(hit[1], 4), | |
| "Content": content, | |
| # "Metadata:": {k: v for k, v in hit[0].metadata.items() if k != "page_content"} | |
| }) | |
| return json.dumps(results, indent=2, ensure_ascii=False) | |
| # --- GIAO DIỆN GRADIO --- | |
| with gr.Blocks(title="Qdrant Vector DB Demo") as demo: | |
| gr.Markdown("# 🚀 Demo Qdrant Vector Search") | |
| gr.Markdown("Tool test nhanh khả năng thêm dữ liệu và tìm kiếm ngữ nghĩa (Semantic Search).") | |
| with gr.Tab("2. Tìm Kiếm (Search)"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| txt_query = gr.Textbox(label="Câu truy vấn", placeholder="Ví dụ: Tìm về một số thông tin trên website Bệnh Viện Tâm Anh", lines=2) | |
| gr.Examples( | |
| examples=[ | |
| "Rủi ro khi khâu cổ tử cung", | |
| "Biến chứng của tràn dịch phổi", | |
| "Triệu chứng của viêm phế quản", | |
| "Phòng ngừa đau tim" | |
| ], | |
| inputs=txt_query, | |
| label="Ví dụ mẫu (Click để chọn)" | |
| ) | |
| # Component mới: Chọn thuật toán | |
| radio_type = gr.Radio( | |
| choices=["Default", "MMR"], | |
| value="Default", | |
| label="Search Type", | |
| info="Default: Giống nhất | MMR: Đa dạng kết quả" | |
| ) | |
| # Component mới: Slider cho MMR | |
| # visible=False mặc định, sẽ hiện khi chọn MMR (nếu bạn muốn làm xịn, ở đây để luôn True cho dễ) | |
| slider_lambda = gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.5, step=0.1, | |
| label="Độ đa dạng (Lambda)", | |
| info="1.0 = Chính xác nhất (như Default), 0.0 = Đa dạng nhất" | |
| ) | |
| slider_k = gr.Slider(minimum=1, maximum=20, value=3, step=1, label="Số lượng kết quả (Top K)") | |
| btn_search = gr.Button("🔍 Tìm kiếm ngay", variant="primary") | |
| with gr.Column(scale=2): | |
| out_search = gr.Code(label="Kết quả trả về (JSON)", language="json") | |
| # Cập nhật inputs truyền vào hàm search | |
| btn_search.click( | |
| search_document, | |
| inputs=[txt_query, slider_k, radio_type, slider_lambda], | |
| outputs=out_search | |
| ) | |
| print("Launching server on 0.0.0.0:7860...") | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| show_error=True, | |
| ) |