Mega_QA / app.py
adamtobegreat's picture
Update app.py
5ba153f verified
raw
history blame
7.81 kB
"""
======================================================
📘 金融客服小智(Fintech Assistant)
版本:v2 (重構示範 by Supervisor)
改進重點:
1. 模組化程式結構(易維護)
2. 加入記憶體保存(多輪對話)
3. 改善 Chroma 初始化與 QA 擷取
4. 加強異常處理與容錯提示
======================================================
"""
import os, re, base64
import chromadb
import gradio as gr
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
# === 記憶模組相容多版本 ===
try:
from langchain_memory import ConversationBufferMemory
except ImportError:
try:
from langchain.memory import ConversationBufferMemory
except ImportError:
from langchain_community.memory import ConversationBufferMemory
# =============================================
# 1️⃣ Embedding 與基礎設定
# =============================================
# Embedding model handed to every Chroma collection created further down.
embedding = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")

# Data files are resolved relative to this script, not the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
QA_PATH, LOGO_PATH = (
    os.path.join(BASE_DIR, filename) for filename in ("QA_v2.txt", "mega.png")
)

# The QA knowledge file is mandatory: abort start-up when it is missing.
if not os.path.exists(QA_PATH):
    raise FileNotFoundError("❌ 找不到 QA 檔案 QA_v2.txt,請確認路徑。")

# A missing API key is tolerated; only a warning is printed here.
API_KEY = os.environ.get("GOOGLE_API_KEY")
if not API_KEY:
    print("⚠️ 尚未設定 GOOGLE_API_KEY,系統將以模擬回覆運行。")
# =============================================
# 2️⃣ QA 載入與分類(改進版正規化)
# =============================================
def load_qa_documents(path: str):
    """Read the QA text file and group its entries by business category.

    The file is split into entries that start at a ``Q`` marker and contain
    an ``A`` marker; each marker may be followed by either an ASCII colon or
    a full-width colon (U+FF1A). Every entry becomes one ``Document`` whose
    ``page_content`` is the stripped entry text.

    Args:
        path: Path to the UTF-8 encoded QA text file.

    Returns:
        A dict with the keys "證券", "期貨" and "複委託", each mapping to a
        list of ``Document`` objects. An entry goes to the first of those
        category names found in its text; entries naming none of them are
        filed under "證券".

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    # The colon class is spelled with an escape so the full-width variant
    # survives editors/tools that normalise it to a plain ':'.
    # The lookahead ends an entry at the next "Q:" marker or at end of text.
    pattern = r"(Q[:\uFF1A].*?A[:\uFF1A].*?)(?=Q[:\uFF1A]|$)"
    qas = re.findall(pattern, text, flags=re.S)

    categories = {"證券": [], "期貨": [], "複委託": []}
    for qa in qas:
        # Dict order doubles as match priority; "證券" is also the fallback.
        target = next((name for name in categories if name in qa), "證券")
        categories[target].append(Document(page_content=qa.strip()))
    return categories
# Parse the QA file once at start-up and report how many entries each
# category received.
qa_docs = load_qa_documents(QA_PATH)
print(
    "✅ 已成功載入 QA 檔案,共分為:",
    {name: len(entries) for name, entries in qa_docs.items()},
)
# =============================================
# 3️⃣ 向量資料庫初始化(避免重複寫入)
# =============================================
# One persistent Chroma collection per business category, stored under
# ./chroma_db (relative to the current working directory).
client = chromadb.PersistentClient(path="./chroma_db")
collection_map = {"證券": "stocks", "期貨": "futures", "複委託": "overseas"}
vectordbs = {}
for cat, docs in qa_docs.items():
    vectordb = Chroma(
        client=client,
        collection_name=collection_map[cat],
        embedding_function=embedding,
    )
    # Seed a collection only when it is still empty, so restarts do not
    # insert duplicates. A category with no QA entries is skipped entirely
    # rather than sending an empty batch to Chroma.
    # NOTE(review): because of the emptiness check, edits to the QA file are
    # not re-indexed while ./chroma_db already holds data — confirm intended.
    if docs and vectordb._collection.count() == 0:
        vectordb.add_documents(docs)
    vectordbs[cat] = vectordb
print("✅ 向量資料庫已建立完成。")
# =============================================
# 4️⃣ 初始化 LLM 與對話記憶
# =============================================
# A Gemini chat model is created only when an API key is configured;
# otherwise llm stays None, which marks simulation mode.
llm = (
    ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY)
    if API_KEY
    else None
)

# Module-level conversation buffer exposed under the "chat_history" key.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# =============================================
# 5️⃣ 對話邏輯
# =============================================
def auto_detect_category(text: str):
    """Guess which business category a user question belongs to.

    Category-specific keywords are checked in a fixed order (securities,
    futures, overseas sub-brokerage) and the first category with a keyword
    contained in ``text`` wins.

    Generic action words such as 開戶 / 下單 / 交割 are deliberately not
    treated as securities keywords: they also occur in futures and overseas
    questions (e.g. "複委託要如何下單?") and would otherwise misroute them.
    Questions containing only such words still land in the default category.

    Args:
        text: The raw user question.

    Returns:
        One of "證券", "期貨" or "複委託"; "證券" when nothing matches.
    """
    rules = (
        ("證券", ("股票", "證券")),
        ("期貨", ("期貨", "選擇權", "保證金")),
        ("複委託", ("複委託", "海外", "美股", "港股")),
    )
    for category, keywords in rules:
        if any(keyword in text for keyword in keywords):
            return category
    # Default bucket, also reached by generic words like 開戶 / 下單 / 交割.
    return "證券"
def chat_fn(message, history):
    """Answer one user message using the QA knowledge base.

    Steps: detect the business category, fetch the two most similar QA
    entries from that category's vector store, build a prompt around them,
    ask the LLM (or return a canned reply when no LLM is configured), and
    record the turn in the module-level conversation memory.

    Args:
        message: The user's question.
        history: Chat history supplied by the UI. Currently unused; kept so
            the signature matches what the UI handler passes in.

    Returns:
        The reply text. LLM failures are reported as a reply beginning with
        "⚠️ 生成錯誤" instead of being raised.
    """
    category = auto_detect_category(message)
    vectordb = vectordbs[category]
    docs = vectordb.similarity_search(message, k=2)
    context = "\n\n".join(d.page_content for d in docs) if docs else "查無相關資料"

    # Same text as a triple-quoted block: leading and trailing newline kept.
    prompt = (
        "\n"
        "你是一位金融客服人員,請根據以下QA知識回答:\n"
        "---\n"
        f"{context}\n"
        "---\n"
        f"使用者問題:{message}\n"
    )

    try:
        if llm:
            response = llm.invoke(prompt)
            reply = getattr(response, "content", None) or getattr(response, "text", None)
            # On some message classes ``text`` is a method rather than a
            # plain attribute; call it so a string is returned, not a
            # bound method.
            if callable(reply):
                reply = reply()
            reply = reply or "⚠️ 無回覆"
        else:
            reply = "(模擬模式)這是示範回覆:請確認是否已設定 GOOGLE_API_KEY。"
    except Exception as e:
        # Top-level boundary for the UI: surface the error as the reply.
        reply = f"⚠️ 生成錯誤:{e}"

    # Record the turn. The memory expects exactly one input key and one
    # output key; multi-key dicts such as {"role": ..., "content": ...} are
    # rejected with a ValueError.
    memory.save_context({"input": message}, {"output": reply})
    return reply
# =============================================
# 6️⃣ Gradio 介面(重構版)
# =============================================
# Inline the logo as base64 so it can be embedded in a data: URI; an empty
# string means no logo file was found.
if os.path.exists(LOGO_PATH):
    with open(LOGO_PATH, "rb") as logo_file:
        logo_base64 = base64.b64encode(logo_file.read()).decode("utf-8")
else:
    logo_base64 = ""
# Gradio UI: chat panel with input row on the left, example-question
# buttons on the right, footer underneath.
# NOTE(review): the source arrived with indentation stripped; the nesting
# below is the most plausible reconstruction — confirm the layout.
with gr.Blocks(
    theme="soft",
    css="""
#logo-top {
position: fixed; top: 12px; left: 18px;
background-color: white; border-radius: 10px;
padding: 6px 8px; box-shadow: 0 0 8px rgba(0,0,0,0.15);
pointer-events: none;
}
#logo-top img { width: 120px; height: auto; display: block; }
#footer { text-align:center; font-size:13px; color:#aaa; margin-top: 20px; }
""",
) as demo:
    # Floating logo, rendered only when the image was found and encoded.
    if logo_base64:
        gr.HTML(f"<div id='logo-top'><img src='data:image/png;base64,{logo_base64}'></div>")

    gr.Markdown("## 👨‍💼 我是小智 · 您的金融好幫手 🫰")
    gr.Markdown("Powered by Gemini & LangChain")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="💬 對話紀錄", type="messages", height=500)
            with gr.Row():
                user_input = gr.Textbox(placeholder="請輸入問題...", show_label=False)
                send_btn = gr.Button("送出", variant="primary")

            def handle_input(message, history):
                """Append one user/assistant exchange and clear the textbox.

                Whitespace-only messages leave the history untouched.
                """
                cleared_box = gr.update(value="")
                if message.strip():
                    answer = chat_fn(message, history)
                    history += [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": answer},
                    ]
                return history, cleared_box

            # Pressing Enter and clicking the send button share one handler.
            for register in (user_input.submit, send_btn.click):
                register(handle_input, [user_input, chatbot], [chatbot, user_input])

            def clear_all():
                """Reset the conversation memory, the chat log and the textbox."""
                memory.clear()
                return [], gr.update(value="")

            gr.Button("🧹 清除對話").click(clear_all, outputs=[chatbot, user_input])

        with gr.Column(scale=1):
            gr.Markdown("### 🔍 常見問題")
            examples = [
                "未成年可以開戶嗎?",
                "法人開戶要準備什麼?",
                "期貨交易保證金是什麼?",
                "複委託要如何下單?",
                "美股交易時間?",
                "美股可以定期定額嗎?",
            ]
            for question in examples:
                # Bind the question as a default argument so each button
                # sends its own text rather than the last loop value.
                gr.Button(question).click(
                    lambda chat_log, question=question: handle_input(question, chat_log),
                    [chatbot],
                    [chatbot, user_input],
                )

    gr.HTML("<div id='footer'>© Fintech Assistant — 僅業務使用,非官方授權</div>")

demo.launch()