PBThuong96 commited on
Commit
3504b37
·
verified ·
1 Parent(s): 7947767

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -50
app.py CHANGED
@@ -3,25 +3,25 @@ import sys
3
  import logging
4
  import gradio as gr
5
 
6
- # --- 1. SỬA LỖI SQLITE TRÊN HUGGING FACE (BẮT BUỘC ĐỂ ĐẦU FILE) ---
7
  try:
8
  __import__("pysqlite3")
9
  sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
10
  except ImportError:
11
- pass # Nếu chạy local không có pysqlite3 thì bỏ qua
12
 
13
  import chromadb
14
  from langchain_google_genai import ChatGoogleGenerativeAI
15
  from langchain_chroma import Chroma
16
  from langchain_huggingface import HuggingFaceEmbeddings
 
 
17
  from langchain_community.retrievers import BM25Retriever
18
- from langchain.retrievers.ensemble import EnsembleRetriever
 
19
  from langchain.chains import create_retrieval_chain
20
  from langchain.chains.combine_documents import create_stuff_documents_chain
21
  from langchain_core.prompts import ChatPromptTemplate
22
- from langchain.retrievers import ContextualCompressionRetriever
23
- from langchain.retrievers.document_compressors import CrossEncoderReranker
24
- from langchain_community.cross_encoders import HuggingFaceCrossEncoder
25
  from langchain_core.documents import Document
26
 
27
  # --- CẤU HÌNH ---
@@ -38,58 +38,75 @@ def get_category_vn_name(cat_code):
38
  "association": "🌐 Hiệp Hội"
39
  }.get(cat_code, "Khác")
40
 
41
- # --- 2. LOAD DB ĐÃ (KHÔNG BUILD LẠI) ---
42
  def get_retrievers():
43
  if not os.path.exists(DB_PATH):
44
- # Lỗi phổ biến nhất: Quên upload hoặc upload sai chỗ
45
  raise FileNotFoundError(f"❌ LỖI: Không tìm thấy thư mục '{DB_PATH}'. Bạn đã upload folder này vào phần Files chưa?")
46
 
47
- logging.info("--- Đang tải dữ liệu... ---")
48
  embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
49
  vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)
50
 
51
- # Tái tạo BM25 từ VectorStore
52
- all_data = vectorstore.get()
53
-
54
- # Kiểm tra nếu DB rỗng
55
- if not all_data['documents']:
56
- raise ValueError("❌ LỖI: Database rỗng! thể quá trình build_db ở máy local bị lỗi hoặc chưa có file nào được xử lý.")
57
-
58
- splits = [Document(page_content=txt, metadata=m) for txt, m in zip(all_data['documents'], all_data['metadatas'])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Mode 1: FAST (Chỉ thuốc)
61
- vec_fast = vectorstore.as_retriever(search_kwargs={"k": 5, "filter": {"category": "drug_info"}})
62
  drug_docs = [d for d in splits if d.metadata.get("category") == "drug_info"]
63
- bm25_fast = BM25Retriever.from_documents(drug_docs) if drug_docs else None
64
- if bm25_fast: bm25_fast.k = 5
65
-
66
- fast_retriever = EnsembleRetriever(retrievers=[bm25_fast, vec_fast], weights=[0.4, 0.6]) if bm25_fast else vec_fast
67
 
68
- # Mode 2: DEEP (Ưu tiên Thanh Ba)
 
69
  cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
70
- vec_deep = vectorstore.as_retriever(search_kwargs={"k": 25, "filter": {"category": {"$in": cats}}})
71
  deep_docs = [d for d in splits if d.metadata.get("category") in cats]
72
- bm25_deep = BM25Retriever.from_documents(deep_docs) if deep_docs else None
73
- if bm25_deep: bm25_deep.k = 25
74
-
75
- ensemble = EnsembleRetriever(retrievers=[bm25_deep, vec_deep], weights=[0.5, 0.5]) if bm25_deep else vec_deep
76
-
77
- # Rerank
78
- reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
79
- compressor = CrossEncoderReranker(model=reranker, top_n=10)
80
- deep_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble)
81
 
 
 
 
 
 
 
 
 
 
82
  return fast_retriever, deep_retriever
83
 
84
  # --- 3. BOT LOGIC ---
85
  class DeepMedBot:
86
  def __init__(self):
87
  self.ready = False
88
- self.init_error = "Đang khởi động..." # Lưu lỗi để hiển thị ra màn hình chat
89
 
90
  if not GOOGLE_API_KEY:
91
- self.init_error = "❌ LỖI: Chưa API Key.\nVui lòng vào Settings -> Variables and secrets -> New secret.\nName: GOOGLE_API_KEY\nValue: (Key của bạn)"
92
- logging.error(self.init_error)
93
  return
94
 
95
  try:
@@ -97,13 +114,14 @@ class DeepMedBot:
97
  self.llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
98
  self._build_chains()
99
  self.ready = True
100
- self.init_error = "" # Xóa lỗi nếu thành công
 
101
  except Exception as e:
102
- self.init_error = f"❌ LỖI KHỞI TẠO: {str(e)}"
103
  logging.error(self.init_error)
104
 
105
  def _build_chains(self):
106
- # Prompt Nhanh (Bảng Thuốc)
107
  fast_sys = (
108
  "Bạn là Dược sĩ Lâm sàng.\n"
109
  "Tra cứu [💊 Thuốc Nội Bộ] và trả lời bằng **Bảng Markdown**:\n"
@@ -115,7 +133,7 @@ class DeepMedBot:
115
  fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
116
  self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
117
 
118
- # Prompt Chuyên sâu (Phác đồ Thanh Ba + Bảng)
119
  deep_sys = (
120
  "Bạn là Bác sĩ Trưởng khoa.\n"
121
  "1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
@@ -132,18 +150,19 @@ class DeepMedBot:
132
  self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
133
 
134
  def chat(self, msg, history, mode):
135
- # Nếu chưa sẵn sàng, trả về chính xác lỗi gì cho người dùng biết
136
  if not self.ready:
137
- return f"⚠️ HỆ THỐNG CHƯA SẴN SÀNG.\n\nNguyên nhân lỗi:\n{self.init_error}\n\nVui lòng kiểm tra lại cấu hình trên Hugging Face."
138
 
139
  chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
140
- res = chain.invoke({"input": msg})
141
-
142
- ans = res['answer']
143
- if 'context' in res and res['context']:
144
- refs = list(set([f"- [{get_category_vn_name(d.metadata.get('category'))}] {d.metadata.get('source')}" for d in res['context']]))
145
- ans += "\n\n---\n📚 **Nguồn:**\n" + "\n".join(refs)
146
- return ans
 
 
147
 
148
  bot = DeepMedBot()
149
 
 
3
  import logging
4
  import gradio as gr
5
 
6
+ # --- 1. SỬA LỖI SQLITE ---
7
  try:
8
  __import__("pysqlite3")
9
  sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
10
  except ImportError:
11
+ pass
12
 
13
  import chromadb
14
  from langchain_google_genai import ChatGoogleGenerativeAI
15
  from langchain_chroma import Chroma
16
  from langchain_huggingface import HuggingFaceEmbeddings
17
+ # CẬP NHẬT IMPORT: Import trực tiếp từ langchain.retrievers để tránh lỗi ModuleNotFoundError
18
+ from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
19
  from langchain_community.retrievers import BM25Retriever
20
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
21
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
22
  from langchain.chains import create_retrieval_chain
23
  from langchain.chains.combine_documents import create_stuff_documents_chain
24
  from langchain_core.prompts import ChatPromptTemplate
 
 
 
25
  from langchain_core.documents import Document
26
 
27
  # --- CẤU HÌNH ---
 
38
  "association": "🌐 Hiệp Hội"
39
  }.get(cat_code, "Khác")
40
 
41
+ # --- 2. LOAD DB VỚI CHẾ AN TOÀN (SAFE LOAD) ---
42
  def get_retrievers():
43
  if not os.path.exists(DB_PATH):
 
44
  raise FileNotFoundError(f"❌ LỖI: Không tìm thấy thư mục '{DB_PATH}'. Bạn đã upload folder này vào phần Files chưa?")
45
 
46
+ logging.info("--- Đang tải dữ liệu từ ChromaDB... ---")
47
  embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
48
  vectorstore = Chroma(persist_directory=DB_PATH, embedding_function=embedding)
49
 
50
+ # Kiểm tra dữ liệu
51
+ try:
52
+ all_data = vectorstore.get()
53
+ if not all_data['documents']:
54
+ raise ValueError("Database rỗng")
55
+ # Tái tạo Documents để dùng cho BM25
56
+ splits = [Document(page_content=txt, metadata=m) for txt, m in zip(all_data['documents'], all_data['metadatas'])]
57
+ except Exception as e:
58
+ logging.error(f"Lỗi đọc dữ liệu Chroma: {e}")
59
+ raise ValueError(f"Không thể đọc dữ liệu từ ChromaDB: {e}")
60
+
61
+ # --- HÀM TẠO RETRIEVER AN TOÀN ---
62
+ def create_safe_retriever(k_val, filter_dict, doc_subset):
63
+ # 1. Luôn tạo Vector Retriever (Cái này ít lỗi nhất)
64
+ vec_retriever = vectorstore.as_retriever(search_kwargs={"k": k_val, "filter": filter_dict})
65
+
66
+ # 2. Cố gắng tạo Hybrid (BM25 + Vector)
67
+ try:
68
+ if not doc_subset:
69
+ return vec_retriever # Không có docs thì dùng vector thôi
70
+
71
+ bm25 = BM25Retriever.from_documents(doc_subset)
72
+ bm25.k = k_val
73
+ # Thử tạo Ensemble (Đây là chỗ hay gây lỗi '_type')
74
+ ensemble = EnsembleRetriever(retrievers=[bm25, vec_retriever], weights=[0.4, 0.6])
75
+ return ensemble
76
+ except Exception as e:
77
+ logging.warning(f"⚠️ Không thể tạo Hybrid Search (Lỗi: {e}). Đang chuyển sang dùng Vector Search thuần túy.")
78
+ return vec_retriever # Fallback về Vector nếu Hybrid lỗi
79
 
80
+ # Mode 1: FAST
81
+ logging.info("--- Khởi tạo Fast Retriever ---")
82
  drug_docs = [d for d in splits if d.metadata.get("category") == "drug_info"]
83
+ fast_retriever = create_safe_retriever(5, {"category": "drug_info"}, drug_docs)
 
 
 
84
 
85
+ # Mode 2: DEEP
86
+ logging.info("--- Khởi tạo Deep Retriever ---")
87
  cats = ["local_regimen", "moh_regimen", "association", "drug_info"]
 
88
  deep_docs = [d for d in splits if d.metadata.get("category") in cats]
89
+ base_deep_retriever = create_safe_retriever(25, {"category": {"$in": cats}}, deep_docs)
 
 
 
 
 
 
 
 
90
 
91
+ # 3. Thử tạo Reranker (Cũng có thể gây lỗi '_type')
92
+ try:
93
+ reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
94
+ compressor = CrossEncoderReranker(model=reranker, top_n=10)
95
+ deep_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_deep_retriever)
96
+ except Exception as e:
97
+ logging.warning(f"⚠️ Không thể tải Reranker (Lỗi: {e}). Dùng bộ tìm kiếm cơ bản.")
98
+ deep_retriever = base_deep_retriever
99
+
100
  return fast_retriever, deep_retriever
101
 
102
  # --- 3. BOT LOGIC ---
103
  class DeepMedBot:
104
  def __init__(self):
105
  self.ready = False
106
+ self.init_error = "Đang khởi động..."
107
 
108
  if not GOOGLE_API_KEY:
109
+ self.init_error = "❌ LỖI: Chưa cấu hình GOOGLE_API_KEY trong Settings."
 
110
  return
111
 
112
  try:
 
114
  self.llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
115
  self._build_chains()
116
  self.ready = True
117
+ self.init_error = ""
118
+ logging.info("✅ BOT KHỞI ĐỘNG THÀNH CÔNG!")
119
  except Exception as e:
120
+ self.init_error = f"❌ LỖI KHỞI TẠO NGHIÊM TRỌNG: {str(e)}"
121
  logging.error(self.init_error)
122
 
123
  def _build_chains(self):
124
+ # Prompt Nhanh
125
  fast_sys = (
126
  "Bạn là Dược sĩ Lâm sàng.\n"
127
  "Tra cứu [💊 Thuốc Nội Bộ] và trả lời bằng **Bảng Markdown**:\n"
 
133
  fast_chain = create_stuff_documents_chain(self.llm, ChatPromptTemplate.from_messages([("system", fast_sys), ("human", "{input}")]))
134
  self.fast_chain = create_retrieval_chain(self.fast_retriever, fast_chain)
135
 
136
+ # Prompt Chuyên sâu
137
  deep_sys = (
138
  "Bạn là Bác sĩ Trưởng khoa.\n"
139
  "1. **Tìm phác đồ:** Ưu tiên tuyệt đối [🏥 Phác Đồ Thanh Ba]. Nếu không có mới dùng [Bộ Y Tế].\n"
 
150
  self.deep_chain = create_retrieval_chain(self.deep_retriever, deep_chain)
151
 
152
  def chat(self, msg, history, mode):
 
153
  if not self.ready:
154
+ return f"⚠️ HỆ THỐNG GẶP LỖI.\n\nChi tiết lỗi:\n{self.init_error}\n\nHãy thử Restart Space trong phần Settings."
155
 
156
  chain = self.deep_chain if mode == "Chuyên sâu" else self.fast_chain
157
+ try:
158
+ res = chain.invoke({"input": msg})
159
+ ans = res['answer']
160
+ if 'context' in res and res['context']:
161
+ refs = list(set([f"- [{get_category_vn_name(d.metadata.get('category'))}] {d.metadata.get('source')}" for d in res['context']]))
162
+ ans += "\n\n---\n📚 **Nguồn:**\n" + "\n".join(refs)
163
+ return ans
164
+ except Exception as e:
165
+ return f"❌ Lỗi khi trả lời: {str(e)}"
166
 
167
  bot = DeepMedBot()
168