import os import sys import torch import streamlit as st from transformers import AutoTokenizer ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.join(ROOT_DIR, 'src')) sys.path.append(os.path.join(ROOT_DIR, 'src', 'NLU')) sys.path.append(os.path.join(ROOT_DIR, 'src', 'rag')) from rag.predict import predict, load_encoder from model_intent import JointPhoBERTModel from main import build_prompt, generate_answer, setup_openai, setup_hf_client, load_env_vars, load_bm25_retriever, reciprocal_rank_fusion from langchain_huggingface import HuggingFaceEmbeddings from langchain_chroma import Chroma st.set_page_config( page_title="Trợ lý Y tế AI - Chatbot RAG", page_icon="🏥", layout="wide", initial_sidebar_state="expanded" ) st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_nlu_resources(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_name = "vinai/phobert-base-v2" train_path = os.path.join(ROOT_DIR, "data", "train.json") ckpt_path = os.path.join(ROOT_DIR, "src", "checkpoints", "best_joint_model.pth") if not os.path.exists(train_path): st.error(f"Không tìm thấy dữ liệu tập train tại `{train_path}`.") st.stop() if not os.path.exists(ckpt_path): st.warning(f"Chưa tìm thấy file checkpoint mô hình NLU tại `{ckpt_path}`. Vui lòng chạy huấn luyện trước bằng cách thực thi `python src/NLU/train_intent.py`.") return None, None, None, device encoder = load_encoder(train_path) tokenizer = AutoTokenizer.from_pretrained(model_name) model = JointPhoBERTModel(model_name, encoder.get_num_intents(), encoder.get_num_ner_tags()) model.load_state_dict(torch.load(ckpt_path, map_location=device)) model.to(device) model.eval() return model, tokenizer, encoder, device @st.cache_resource def load_vector_store(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') persist_dir = os.path.join(ROOT_DIR, "src", "rag", "chroma_db") if not os.path.exists(persist_dir): st.warning(f"Chưa tìm thấy Cơ sở dữ liệu Vector tại `{persist_dir}`. Vui lòng tạo DB trước bằng cách thực thi `python src/rag/offline_rag.py`.") return None embeddings = HuggingFaceEmbeddings( model_name="keepitreal/vietnamese-sbert", model_kwargs={'device': device} ) vector_db = Chroma(persist_directory=persist_dir, embedding_function=embeddings) return vector_db @st.cache_resource def load_bm25_store(): return load_bm25_retriever() # Khởi tạo mô hình và DB nlu_model, tokenizer, encoder, device = load_nlu_resources() vector_db = load_vector_store() bm25_retriever = load_bm25_store() # Đọc API Key mặc định từ hệ thống hoặc .env env_vars = load_env_vars(os.path.join(ROOT_DIR, "src", ".env")) api_key_input = os.environ.get("OPENAI_API_KEY", env_vars.get("OPENAI_API_KEY", "")) api_base_input = os.environ.get("OPENAI_API_BASE", env_vars.get("OPENAI_API_BASE", "https://models.inference.ai.azure.com")) hf_token_input = os.environ.get("HF_TOKEN", env_vars.get("HF_TOKEN", "")) with st.sidebar: st.image("https://cdn-icons-png.flaticon.com/512/809/809957.png", width=80) st.markdown("### Câu hỏi gợi ý") sample_queries = [ "Bé bị sốt cao co giật phải làm sao?", "Triệu chứng khi trẻ bị viêm phổi?", "Khi nào cần đưa trẻ bị tiêu chảy đi khám ngay?" ] for q in sample_queries: if st.button(q, use_container_width=True, key=q): st.session_state.suggested_query = q st.divider() if st.button("Xóa lịch sử trò chuyện", use_container_width=True): st.session_state.messages = [] st.rerun() st.markdown("""

🏥 TRỢ LÝ Y TẾ AI THÔNG MINH

Kết hợp hiểu ý định người dùng (PhoBERT NLU) & Cơ sở dữ liệu tài liệu y khoa chính thống (RAG)

""", unsafe_allow_html=True) if "messages" not in st.session_state: st.session_state.messages = [] # Hiển thị các tin nhắn trong lịch sử for message in st.session_state.messages: with st.chat_message(message["role"]): st.write(message["content"]) if "analysis" in message: analysis = message["analysis"] with st.expander("🔍 Chi tiết phân tích hệ thống"): col1, col2 = st.columns(2) with col1: st.markdown("##### Phân tích Ý định & Thực thể (NLU)") st.markdown(f"**Ý định (Intent):** {analysis['intent']}", unsafe_allow_html=True) if analysis['entities']: st.markdown("**Thực thể tìm thấy:**") for ent, tag in analysis['entities']: # Tạo màu sắc khác nhau dựa trên nhãn tag color = "#e91e63" if "DISEASE" in tag or "B-" in tag else "#4caf50" st.markdown(f"{ent} ({tag})", unsafe_allow_html=True) else: st.markdown("*Không phát hiện thực thể y tế cụ thể.*") with col2: st.markdown("##### Tài liệu tra cứu (RAG)") if analysis['docs']: for idx, doc in enumerate(analysis['docs']): source = os.path.basename(doc.get("source", "Tài liệu y tế")) st.markdown(f"""
Tài liệu {idx+1} (Nguồn: {source})
{doc.get("content")[:200]}...
""", unsafe_allow_html=True) else: st.markdown("*Không tìm thấy tài liệu liên quan.*") # Nhận câu hỏi mới query = None if "suggested_query" in st.session_state and st.session_state.suggested_query: query = st.session_state.suggested_query st.session_state.suggested_query = None # Reset else: query = st.chat_input("Nhập câu hỏi y tế của bạn tại đây...") if query: with st.chat_message("user"): st.write(query) st.session_state.messages.append({"role": "user", "content": query}) with st.chat_message("assistant"): with st.spinner("Đang phân tích câu hỏi và tra cứu y văn..."): # --- Bước 1: Phân tích NLU (Intent + NER) --- if nlu_model is not None: intent, entities = predict(query, nlu_model, tokenizer, encoder, device) entity_words = [word.replace('_', ' ') for word, tag in entities if tag != "O"] entities_filtered = [(word.replace('_', ' '), tag) for word, tag in entities if tag != "O"] else: intent = "Không rõ (Chưa load mô hình)" entity_words = [] entities_filtered = [] # --- Bước 2: RAG Retrieval --- retrieved_docs = [] retrieved_docs_serializable = [] if vector_db is not None: # Dùng phiên bản chữ thường và tăng cường (Query Expansion) search_query = query.strip().lower() if "là gì" in search_query or "thế nào là" in search_query: search_query += " đại cương định nghĩa khái niệm" vector_docs = vector_db.similarity_search(search_query, k=5) bm25_docs = [] if bm25_retriever: bm25_retriever.k = 5 if entity_words: bm25_query = " ".join(entity_words) bm25_docs = bm25_retriever.invoke(bm25_query) else: bm25_docs = bm25_retriever.invoke(search_query) retrieved_docs = reciprocal_rank_fusion(vector_docs, bm25_docs, top_n=5) for doc in retrieved_docs: retrieved_docs_serializable.append({ "source": doc.metadata.get("source", "Tài liệu y tế"), "content": doc.page_content }) # --- Bước 3: Sinh câu trả lời với LLM --- if not api_key_input and not hf_token_input: answer = "Không tìm thấy API Key nào. Vui lòng thiết lập biến môi trường OPENAI_API_KEY hoặc HF_TOKEN để kích hoạt trả lời từ AI." else: openai_client = setup_openai(api_key_input, api_base_input) if api_key_input else None hf_client = setup_hf_client(hf_token_input) if hf_token_input else None prompt = build_prompt(query, intent, entity_words, retrieved_docs) answer = generate_answer(openai_client, hf_client, prompt) # Hiển thị câu trả lời st.write(answer) # Lưu phân tích vào session state để hiển thị lại analysis_info = { "intent": intent, "entities": entities_filtered, "docs": retrieved_docs_serializable } # Đính kèm phân tích hệ thống vào tin nhắn cuối cùng để kết xuất st.session_state.messages.append({ "role": "assistant", "content": answer, "analysis": analysis_info }) st.rerun() st.markdown("---") st.caption("**Khuyến cáo y tế**: Thông tin cung cấp bởi chatbot chỉ mang tính chất tham khảo dựa trên tài liệu hướng dẫn y khoa chính thống có sẵn và không thay thế cho chẩn đoán, điều trị của bác sĩ chuyên khoa.")