Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
try:
|
| 4 |
__import__("pysqlite3")
|
| 5 |
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
| 6 |
except ImportError:
|
| 7 |
-
pass
|
| 8 |
-
|
| 9 |
import chromadb
|
| 10 |
-
import gradio as gr
|
| 11 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 12 |
from langchain_chroma import Chroma
|
| 13 |
from langchain_huggingface import HuggingFaceEmbeddings
|
|
@@ -21,8 +24,9 @@ from langchain.retrievers.document_compressors import CrossEncoderReranker
|
|
| 21 |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 22 |
from langchain_core.documents import Document
|
| 23 |
|
|
|
|
| 24 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 25 |
-
DB_PATH = "chroma_db"
|
| 26 |
|
| 27 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
|
| 28 |
|
|
@@ -34,61 +38,71 @@ def get_category_vn_name(cat_code):
|
|
| 34 |
"association": "🌐 Hiệp Hội"
|
| 35 |
}.get(cat_code, "Khác")
|
| 36 |
|
|
|
|
| 37 |
def get_retrievers():
|
| 38 |
if not os.path.exists(DB_PATH):
|
| 39 |
-
raise FileNotFoundError("❌ Chưa upload folder 'chroma_db'!")
|
| 40 |
|
|
|
|
| 41 |
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
|
| 42 |
vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)
|
| 43 |
|
|
|
|
| 44 |
all_data = vectorstore.get()
|
| 45 |
splits = [Document(page_content=txt, metadata=m) for txt, m in zip(all_data['documents'], all_data['metadatas'])]
|
| 46 |
|
|
|
|
| 47 |
vec_fast = vectorstore.as_retriever(search_kwargs={"k": 5, "filter": {"category": "drug_info"}})
|
| 48 |
drug_docs = [d for d in splits if d.metadata.get("category") == "drug_info"]
|
| 49 |
bm25_fast = BM25Retriever.from_documents(drug_docs) if drug_docs else None
|
| 50 |
-
bm25_fast.k = 5
|
| 51 |
|
| 52 |
fast_retriever = EnsembleRetriever(retrievers=[bm25_fast, vec_fast], weights=[0.4, 0.6]) if bm25_fast else vec_fast
|
| 53 |
|
|
|
|
| 54 |
cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
|
| 55 |
vec_deep = vectorstore.as_retriever(search_kwargs={"k": 25, "filter": {"category": {"$in": cats}}})
|
| 56 |
deep_docs = [d for d in splits if d.metadata.get("category") in cats]
|
| 57 |
bm25_deep = BM25Retriever.from_documents(deep_docs) if deep_docs else None
|
| 58 |
-
bm25_deep.k = 25
|
| 59 |
|
| 60 |
ensemble = EnsembleRetriever(retrievers=[bm25_deep, vec_deep], weights=[0.5, 0.5]) if bm25_deep else vec_deep
|
| 61 |
|
|
|
|
| 62 |
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
|
| 63 |
compressor = CrossEncoderReranker(model=reranker, top_n=10)
|
| 64 |
deep_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble)
|
| 65 |
|
| 66 |
return fast_retriever, deep_retriever
|
| 67 |
|
|
|
|
| 68 |
class DeepMedBot:
|
| 69 |
def __init__(self):
|
| 70 |
self.ready = False
|
|
|
|
|
|
|
| 71 |
try:
|
| 72 |
self.fast_retriever, self.deep_retriever = get_retrievers()
|
| 73 |
self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
|
| 74 |
self._build_chains()
|
| 75 |
self.ready = True
|
| 76 |
except Exception as e:
|
| 77 |
-
logging.error(f"Lỗi: {e}")
|
| 78 |
|
| 79 |
def _build_chains(self):
|
| 80 |
-
# Prompt Bảng
|
| 81 |
fast_sys = (
|
| 82 |
-
"Bạn là Dược sĩ
|
| 83 |
-
"
|
|
|
|
| 84 |
"| --- | --- | --- | --- | --- |\n"
|
| 85 |
-
"Nếu không thấy, báo: '❌ Không
|
| 86 |
"Context:\n{context}"
|
| 87 |
)
|
| 88 |
fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
|
| 89 |
self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
|
| 90 |
|
| 91 |
-
# Prompt Phác đồ
|
| 92 |
deep_sys = (
|
| 93 |
"Bạn là Bác sĩ Trưởng khoa.\n"
|
| 94 |
"1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
|
|
@@ -105,7 +119,7 @@ class DeepMedBot:
|
|
| 105 |
self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
|
| 106 |
|
| 107 |
def chat(self, msg, history, mode):
|
| 108 |
-
if not self.ready: return "⚠️ Đang khởi động
|
| 109 |
chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
|
| 110 |
res = chain.invoke({"input": msg})
|
| 111 |
|
|
@@ -120,9 +134,13 @@ bot = DeepMedBot()
|
|
| 120 |
def respond(message, history, mode):
|
| 121 |
return bot.chat(message, history, mode)
|
| 122 |
|
| 123 |
-
gr.ChatInterface(
|
| 124 |
fn=respond,
|
| 125 |
additional_inputs=[gr.Radio(["Tra cứu nhanh (Chỉ thuốc)", "Chuyên sâu"], value="Tra cứu nhanh (Chỉ thuốc)", label="Chế độ")],
|
| 126 |
title="TTYT Thanh Ba - Hỗ trợ Lâm sàng",
|
|
|
|
| 127 |
css=".gradio-container {min_height: 600px}"
|
| 128 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
# --- 1. SỬA LỖI SQLITE TRÊN HUGGING FACE (BẮT BUỘC ĐỂ ĐẦU FILE) ---
|
| 7 |
try:
|
| 8 |
__import__("pysqlite3")
|
| 9 |
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
| 10 |
except ImportError:
|
| 11 |
+
pass # Nếu chạy local không có pysqlite3 thì bỏ qua
|
| 12 |
+
|
| 13 |
import chromadb
|
|
|
|
| 14 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 15 |
from langchain_chroma import Chroma
|
| 16 |
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
|
| 24 |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 25 |
from langchain_core.documents import Document
|
| 26 |
|
| 27 |
+
# --- CẤU HÌNH ---
|
| 28 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
| 29 |
+
DB_PATH = "chroma_db"
|
| 30 |
|
| 31 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
|
| 32 |
|
|
|
|
| 38 |
"association": "🌐 Hiệp Hội"
|
| 39 |
}.get(cat_code, "Khác")
|
| 40 |
|
| 41 |
+
# --- 2. LOAD DB ĐÃ CÓ (KHÔNG BUILD LẠI) ---
|
| 42 |
def get_retrievers():
|
| 43 |
if not os.path.exists(DB_PATH):
|
| 44 |
+
raise FileNotFoundError("❌ LỖI: Chưa upload folder 'chroma_db' lên Hugging Face!")
|
| 45 |
|
| 46 |
+
logging.info("--- Đang tải dữ liệu... ---")
|
| 47 |
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
|
| 48 |
vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)
|
| 49 |
|
| 50 |
+
# Tái tạo BM25 từ VectorStore
|
| 51 |
all_data = vectorstore.get()
|
| 52 |
splits = [Document(page_content=txt, metadata=m) for txt, m in zip(all_data['documents'], all_data['metadatas'])]
|
| 53 |
|
| 54 |
+
# Mode 1: FAST (Chỉ thuốc)
|
| 55 |
vec_fast = vectorstore.as_retriever(search_kwargs={"k": 5, "filter": {"category": "drug_info"}})
|
| 56 |
drug_docs = [d for d in splits if d.metadata.get("category") == "drug_info"]
|
| 57 |
bm25_fast = BM25Retriever.from_documents(drug_docs) if drug_docs else None
|
| 58 |
+
if bm25_fast: bm25_fast.k = 5
|
| 59 |
|
| 60 |
fast_retriever = EnsembleRetriever(retrievers=[bm25_fast, vec_fast], weights=[0.4, 0.6]) if bm25_fast else vec_fast
|
| 61 |
|
| 62 |
+
# Mode 2: DEEP (Ưu tiên Thanh Ba)
|
| 63 |
cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
|
| 64 |
vec_deep = vectorstore.as_retriever(search_kwargs={"k": 25, "filter": {"category": {"$in": cats}}})
|
| 65 |
deep_docs = [d for d in splits if d.metadata.get("category") in cats]
|
| 66 |
bm25_deep = BM25Retriever.from_documents(deep_docs) if deep_docs else None
|
| 67 |
+
if bm25_deep: bm25_deep.k = 25
|
| 68 |
|
| 69 |
ensemble = EnsembleRetriever(retrievers=[bm25_deep, vec_deep], weights=[0.5, 0.5]) if bm25_deep else vec_deep
|
| 70 |
|
| 71 |
+
# Rerank
|
| 72 |
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
|
| 73 |
compressor = CrossEncoderReranker(model=reranker, top_n=10)
|
| 74 |
deep_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble)
|
| 75 |
|
| 76 |
return fast_retriever, deep_retriever
|
| 77 |
|
| 78 |
+
# --- 3. BOT LOGIC ---
|
| 79 |
class DeepMedBot:
|
| 80 |
def __init__(self):
|
| 81 |
self.ready = False
|
| 82 |
+
if not GOOGLE_API_KEY:
|
| 83 |
+
return
|
| 84 |
try:
|
| 85 |
self.fast_retriever, self.deep_retriever = get_retrievers()
|
| 86 |
self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
|
| 87 |
self._build_chains()
|
| 88 |
self.ready = True
|
| 89 |
except Exception as e:
|
| 90 |
+
logging.error(f"Lỗi khởi tạo: {e}")
|
| 91 |
|
| 92 |
def _build_chains(self):
|
| 93 |
+
# Prompt Nhanh (Bảng Thuốc)
|
| 94 |
fast_sys = (
|
| 95 |
+
"Bạn là Dược sĩ Lâm sàng.\n"
|
| 96 |
+
"Tra cứu [💊 Thuốc Nội Bộ] và trả lời bằng **Bảng Markdown**:\n"
|
| 97 |
+
"| Tên thuốc | Hoạt chất | Hàm lượng | Đơn vị | Ghi chú |\n"
|
| 98 |
"| --- | --- | --- | --- | --- |\n"
|
| 99 |
+
"Nếu không thấy, báo: '❌ Không tìm thấy trong kho'."
|
| 100 |
"Context:\n{context}"
|
| 101 |
)
|
| 102 |
fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
|
| 103 |
self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
|
| 104 |
|
| 105 |
+
# Prompt Chuyên sâu (Phác đồ Thanh Ba + Bảng)
|
| 106 |
deep_sys = (
|
| 107 |
"Bạn là Bác sĩ Trưởng khoa.\n"
|
| 108 |
"1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
|
|
|
|
| 119 |
self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
|
| 120 |
|
| 121 |
def chat(self, msg, history, mode):
|
| 122 |
+
if not self.ready: return "⚠️ Đang khởi động hoặc Lỗi (Xem Logs trên Hugging Face)..."
|
| 123 |
chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
|
| 124 |
res = chain.invoke({"input": msg})
|
| 125 |
|
|
|
|
| 134 |
def respond(message, history, mode):
|
| 135 |
return bot.chat(message, history, mode)
|
| 136 |
|
| 137 |
+
demo = gr.ChatInterface(
|
| 138 |
fn=respond,
|
| 139 |
additional_inputs=[gr.Radio(["Tra cứu nhanh (Chỉ thuốc)", "Chuyên sâu"], value="Tra cứu nhanh (Chỉ thuốc)", label="Chế độ")],
|
| 140 |
title="TTYT Thanh Ba - Hỗ trợ Lâm sàng",
|
| 141 |
+
description="Hệ thống tra cứu Phác đồ & Thuốc nội bộ.",
|
| 142 |
css=".gradio-container {min_height: 600px}"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
demo.launch()
|