PBThuong96 committed on
Commit
d409128
·
verified ·
1 Parent(s): 1b326bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -16
app.py CHANGED
@@ -1,13 +1,16 @@
1
  import os
2
  import sys
 
 
 
 
3
  try:
4
  __import__("pysqlite3")
5
  sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
6
  except ImportError:
7
- pass
8
- import logging
9
  import chromadb
10
- import gradio as gr
11
  from langchain_google_genai import ChatGoogleGenerativeAI
12
  from langchain_chroma import Chroma
13
  from langchain_huggingface import HuggingFaceEmbeddings
@@ -21,8 +24,9 @@ from langchain.retrievers.document_compressors import CrossEncoderReranker
21
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
22
  from langchain_core.documents import Document
23
 
 
24
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
25
- DB_PATH = "chroma_db"
26
 
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
28
 
@@ -34,61 +38,71 @@ def get_category_vn_name(cat_code):
34
  "association": "🌐 Hiệp Hội"
35
  }.get(cat_code, "Khác")
36
 
 
def get_retrievers():
    """Build the fast (drug-only) and deep (reranked) retrievers from the persisted Chroma DB.

    Returns:
        tuple: ``(fast_retriever, deep_retriever)``.

    Raises:
        FileNotFoundError: If the ``DB_PATH`` directory does not exist.
    """
    if not os.path.exists(DB_PATH):
        raise FileNotFoundError("❌ Chưa upload folder 'chroma_db'!")

    embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)

    # Re-create Document objects from the stored chunks so BM25 can index them.
    all_data = vectorstore.get()
    splits = [Document(page_content=txt, metadata=m) for txt, m in zip(all_data['documents'], all_data['metadatas'])]

    # Fast mode: drug information only.
    vec_fast = vectorstore.as_retriever(search_kwargs={"k": 5, "filter": {"category": "drug_info"}})
    drug_docs = [d for d in splits if d.metadata.get("category") == "drug_info"]
    bm25_fast = BM25Retriever.from_documents(drug_docs) if drug_docs else None
    # Only set ``k`` when a BM25 retriever exists; assigning on None raised AttributeError.
    if bm25_fast is not None:
        bm25_fast.k = 5

    if bm25_fast is not None:
        fast_retriever = EnsembleRetriever(retrievers=[bm25_fast, vec_fast], weights=[0.4, 0.6])
    else:
        # No drug documents to index: fall back to vector search alone.
        fast_retriever = vec_fast

    # Deep mode: all regimen/association/drug categories, then cross-encoder rerank.
    cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
    vec_deep = vectorstore.as_retriever(search_kwargs={"k": 25, "filter": {"category": {"$in": cats}}})
    deep_docs = [d for d in splits if d.metadata.get("category") in cats]
    bm25_deep = BM25Retriever.from_documents(deep_docs) if deep_docs else None
    if bm25_deep is not None:
        bm25_deep.k = 25

    if bm25_deep is not None:
        ensemble = EnsembleRetriever(retrievers=[bm25_deep, vec_deep], weights=[0.5, 0.5])
    else:
        ensemble = vec_deep

    reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=reranker, top_n=10)
    deep_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble)

    return fast_retriever, deep_retriever
67
 
 
68
  class DeepMedBot:
def __init__(self):
    """Load retrievers, the LLM and the chains; leave ``self.ready`` False on any failure."""
    self.ready = False
    try:
        self.fast_retriever, self.deep_retriever = get_retrievers()
        self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
        self._build_chains()
        self.ready = True
    except Exception as e:
        # Top-level startup boundary: keep the app alive but record the full
        # traceback so the failure can be diagnosed from the logs.
        logging.exception("Lỗi: %s", e)
78
 
79
  def _build_chains(self):
80
- # Prompt Bảng cho Thuốc
81
  fast_sys = (
82
- "Bạn là Dược sĩ. Tra cứu [💊 Thuốc Nội Bộ] và trả lời bằng **Bảng Markdown**:\n"
83
- "| Tên thuốc | Hoạt chất | Hàm lượng | ĐVT | Ghi chú |\n"
 
84
  "| --- | --- | --- | --- | --- |\n"
85
- "Nếu không thấy, báo: '❌ Không trong kho'."
86
  "Context:\n{context}"
87
  )
88
  fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
89
  self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
90
 
91
- # Prompt Phác đồ ưu tiên Thanh Ba
92
  deep_sys = (
93
  "Bạn là Bác sĩ Trưởng khoa.\n"
94
  "1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
@@ -105,7 +119,7 @@ class DeepMedBot:
105
  self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
106
 
107
  def chat(self, msg, history, mode):
108
- if not self.ready: return "⚠️ Đang khởi động... Vui lòng đợi 1 phút."
109
  chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
110
  res = chain.invoke({"input": msg})
111
 
@@ -120,9 +134,13 @@ bot = DeepMedBot()
def respond(message, history, mode):
    """Gradio callback: pass the chat turn to the module-level ``bot`` and return its reply."""
    reply = bot.chat(message, history, mode)
    return reply
122
 
# Build the chat UI (with a mode selector as an extra input) and start serving.
gr.ChatInterface(
    fn=respond,
    additional_inputs=[gr.Radio(["Tra cứu nhanh (Chỉ thuốc)", "Chuyên sâu"], value="Tra cứu nhanh (Chỉ thuốc)", label="Chế độ")],
    title="TTYT Thanh Ba - Hỗ trợ Lâm sàng",
    # CSS property names are hyphenated; "min_height" is not a valid property.
    css=".gradio-container {min-height: 600px}",
).launch()
 
 
 
 
1
  import os
2
  import sys
3
+ import logging
4
+ import gradio as gr
5
+
6
+ # --- 1. SỬA LỖI SQLITE TRÊN HUGGING FACE (BẮT BUỘC ĐỂ ĐẦU FILE) ---
7
  try:
8
  __import__("pysqlite3")
9
  sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
10
  except ImportError:
11
+ pass # Nếu chạy local không có pysqlite3 thì bỏ qua
12
+
13
  import chromadb
 
14
  from langchain_google_genai import ChatGoogleGenerativeAI
15
  from langchain_chroma import Chroma
16
  from langchain_huggingface import HuggingFaceEmbeddings
 
24
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
25
  from langchain_core.documents import Document
26
 
27
+ # --- CẤU HÌNH ---
28
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
29
+ DB_PATH = "chroma_db"
30
 
31
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
32
 
 
38
  "association": "🌐 Hiệp Hội"
39
  }.get(cat_code, "Khác")
40
 
41
+ # --- 2. LOAD DB ĐÃ CÓ (KHÔNG BUILD LẠI) ---
def get_retrievers():
    """Load the existing Chroma DB and build the fast and deep retrievers.

    The fast retriever searches only ``drug_info`` chunks. The deep retriever
    searches all regimen/association/drug categories and reranks the merged
    hits with a cross-encoder, keeping the top 10.

    Returns:
        tuple: ``(fast_retriever, deep_retriever)``.

    Raises:
        FileNotFoundError: If the ``DB_PATH`` directory does not exist.
    """
    if not os.path.exists(DB_PATH):
        raise FileNotFoundError("❌ LỖI: Chưa upload folder 'chroma_db' lên Hugging Face!")

    logging.info("--- Đang tải dữ liệu... ---")
    embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)

    # Rebuild Document objects from the vector store so BM25 can index them.
    # ``m or {}`` guards against chunks stored without metadata, for which the
    # store may hand back None and ``d.metadata.get`` below would fail.
    all_data = vectorstore.get()
    splits = [
        Document(page_content=txt, metadata=m or {})
        for txt, m in zip(all_data["documents"], all_data["metadatas"])
    ]

    def _hybrid(docs, k, search_filter, weights):
        """Combine BM25 over ``docs`` with filtered vector search; vector-only if ``docs`` is empty."""
        vec = vectorstore.as_retriever(search_kwargs={"k": k, "filter": search_filter})
        if not docs:
            return vec
        bm25 = BM25Retriever.from_documents(docs)
        bm25.k = k
        return EnsembleRetriever(retrievers=[bm25, vec], weights=weights)

    # Mode 1: FAST (drugs only).
    drug_docs = [d for d in splits if d.metadata.get("category") == "drug_info"]
    fast_retriever = _hybrid(drug_docs, 5, {"category": "drug_info"}, [0.4, 0.6])

    # Mode 2: DEEP (all clinical categories, reranked afterwards).
    cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
    deep_docs = [d for d in splits if d.metadata.get("category") in cats]
    ensemble = _hybrid(deep_docs, 25, {"category": {"$in": cats}}, [0.5, 0.5])

    # Rerank the merged candidates and keep the 10 best.
    reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=reranker, top_n=10)
    deep_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble)

    return fast_retriever, deep_retriever
77
 
78
+ # --- 3. BOT LOGIC ---
79
  class DeepMedBot:
def __init__(self):
    """Initialise retrievers, the LLM and both chains.

    ``self.ready`` stays False when the API key is missing or any step
    raises, so ``chat`` can report the problem instead of crashing.
    """
    self.ready = False
    if not GOOGLE_API_KEY:
        # Log the reason; a silent return leaves nothing in the logs to
        # explain why the bot never becomes ready.
        logging.error("Lỗi khởi tạo: chưa cấu hình GOOGLE_API_KEY")
        return
    try:
        self.fast_retriever, self.deep_retriever = get_retrievers()
        self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
        self._build_chains()
        self.ready = True
    except Exception as e:
        # Top-level startup boundary: keep the UI alive, but record the
        # traceback along with the message.
        logging.exception("Lỗi khởi tạo: %s", e)
91
 
92
  def _build_chains(self):
93
+ # Prompt Nhanh (Bảng Thuốc)
94
  fast_sys = (
95
+ "Bạn là Dược sĩ Lâm sàng.\n"
96
+ "Tra cứu [💊 Thuốc Nội Bộ] trả lời bằng **Bảng Markdown**:\n"
97
+ "| Tên thuốc | Hoạt chất | Hàm lượng | Đơn vị | Ghi chú |\n"
98
  "| --- | --- | --- | --- | --- |\n"
99
+ "Nếu không thấy, báo: '❌ Không tìm thấy trong kho'."
100
  "Context:\n{context}"
101
  )
102
  fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
103
  self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
104
 
105
+ # Prompt Chuyên sâu (Phác đồ Thanh Ba + Bảng)
106
  deep_sys = (
107
  "Bạn là Bác sĩ Trưởng khoa.\n"
108
  "1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
 
119
  self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
120
 
121
  def chat(self, msg, history, mode):
122
+ if not self.ready: return "⚠️ Đang khởi động hoặc Lỗi (Xem Logs trên Hugging Face)..."
123
  chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
124
  res = chain.invoke({"input": msg})
125
 
 
def respond(message, history, mode):
    """Gradio callback: pass the chat turn to the module-level ``bot`` and return its reply."""
    reply = bot.chat(message, history, mode)
    return reply
136
 
# Build the chat UI at import time so hosting platforms can pick up ``demo``.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[gr.Radio(["Tra cứu nhanh (Chỉ thuốc)", "Chuyên sâu"], value="Tra cứu nhanh (Chỉ thuốc)", label="Chế độ")],
    title="TTYT Thanh Ba - Hỗ trợ Lâm sàng",
    description="Hệ thống tra cứu Phác đồ & Thuốc nội bộ.",
    # CSS property names are hyphenated; "min_height" is not a valid property.
    css=".gradio-container {min-height: 600px}",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()