# Mega_QA / app.py
# Hugging Face file-view header: commit "Update app.py" (812867d, verified) by adamtobegreat; 7.53 kB
"""
======================================================
📘 金融客服小智(Fintech Assistant)
版本:v2.1 (Hugging Face 部署版)
改進重點:
1. 改用記憶體型 Chroma,避免 PersistentClient 錯誤
2. 路徑使用 os.getcwd() 以符合 HF Spaces
3. 加入 QA 檔案容錯與模擬模式
4. GOOGLE_API_KEY 以 Secrets 管理
======================================================
"""
import os, re, base64
import chromadb
import gradio as gr
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
# === 記憶模組相容多版本 ===
try:
from langchain_memory import ConversationBufferMemory
except ImportError:
try:
from langchain.memory import ConversationBufferMemory
except ImportError:
from langchain_community.memory import ConversationBufferMemory
# =============================================
# 1️⃣ Embedding 與基礎設定
# =============================================
# Embedding model used when building the vector collections.
embedding = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")

# Data files are resolved against the current working directory.
BASE_DIR = os.getcwd()
QA_PATH = os.path.join(BASE_DIR, "QA_v2.txt")
LOGO_PATH = os.path.join(BASE_DIR, "mega.png")

# The Gemini key is read from the environment; a missing or empty value only
# triggers a warning here.
API_KEY = os.environ.get("GOOGLE_API_KEY")
if API_KEY is None or API_KEY == "":
    print("⚠️ 尚未設定 GOOGLE_API_KEY,將使用模擬模式。")
# =============================================
# 2️⃣ QA 載入與分類
# =============================================
def load_qa_documents(path: str):
    """Read a QA text file and bucket each Q/A pair by category.

    The file is split into entries that start with ``Q:`` and contain an
    ``A:`` part; both the ASCII colon and the full-width colon (``:``) are
    accepted as the marker separator. Each entry becomes one ``Document``.

    An entry is assigned to the first category label (證券, 期貨, 複委託 —
    in that order) that appears anywhere in its text; entries mentioning no
    label fall back to 證券.

    Args:
        path: Path to a UTF-8 encoded QA text file.

    Returns:
        A dict mapping each of the three category labels to a list of
        ``Document`` objects (possibly empty).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    # One entry = "Q: ... A: ..." up to (but excluding) the next "Q:" or the
    # end of the text. re.S lets "." span newlines inside an entry.
    pattern = r"(Q[::].*?A[::].*?)(?=Q[::]|$)"
    qas = re.findall(pattern, text, flags=re.S)

    categories = {"證券": [], "期貨": [], "複委託": []}
    default_category = "證券"
    for qa in qas:
        # Dict insertion order gives the label priority described above.
        target = next((label for label in categories if label in qa), default_category)
        categories[target].append(Document(page_content=qa.strip()))
    return categories
# Load the knowledge base when the QA file is present; otherwise start with
# empty buckets for every category.
if not os.path.exists(QA_PATH):
    print("⚠️ 未找到 QA_v2.txt,啟用空白知識庫模式。")
    qa_docs = {"證券": [], "期貨": [], "複委託": []}
else:
    qa_docs = load_qa_documents(QA_PATH)
    print("✅ 已載入 QA 檔案,共分為:", {k: len(v) for k, v in qa_docs.items()})
# =============================================
# 3️⃣ 向量資料庫初始化(記憶體型)
# =============================================
# Create a chromadb client and build one Chroma collection per QA category.
try:
    client = chromadb.Client()
except Exception:
    # NOTE(review): fallback path for when chromadb.Client() raises; whether
    # chromadb.api actually exposes a Client callable is not verifiable from
    # this file — confirm against the installed chromadb version.
    import chromadb.api
    client = chromadb.api.Client()
# Maps the category labels used as qa_docs keys to Chroma collection names.
collection_map = {"證券": "stocks", "期貨": "futures", "複委託": "overseas"}
# Category label -> Chroma vector store.
vectordbs = {}
for cat, docs in qa_docs.items():
    vectordb = Chroma(client=client, collection_name=collection_map[cat], embedding_function=embedding)
    # Add documents only when the collection reports zero entries and there is
    # something to add. Relies on the private `_collection` attribute.
    if hasattr(vectordb._collection, "count") and vectordb._collection.count() == 0 and docs:
        vectordb.add_documents(docs)
    vectordbs[cat] = vectordb
print("✅ 向量資料庫初始化完成。")
# =============================================
# 4️⃣ 初始化 LLM 與記憶體
# =============================================
# Without an API key no model is created; `llm` stays None (simulation mode).
llm = (
    ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY)
    if API_KEY
    else None
)

# Conversation buffer exposed under the "chat_history" key, as message objects.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# =============================================
# 5️⃣ 對話邏輯
# =============================================
# Keyword table scanned in order; the first category with a hit wins.
# The specific categories (futures, overseas sub-brokerage) come before the
# generic securities keywords so that shared words such as 開戶 / 下單 / 股票
# do not shadow them (e.g. "複委託要如何下單?" must route to 複委託).
_CATEGORY_KEYWORDS = (
    ("期貨", ("期貨", "選擇權", "保證金")),
    ("複委託", ("複委託", "海外", "美股", "港股")),
    ("證券", ("股票", "證券", "開戶", "下單", "交割")),
)


def auto_detect_category(text: str) -> str:
    """Pick the QA category for a user message by keyword matching.

    Args:
        text: The user's message.

    Returns:
        One of "期貨", "複委託" or "證券". Messages that contain none of the
        known keywords (including the empty string) default to "證券".
    """
    for category, keywords in _CATEGORY_KEYWORDS:
        if any(keyword in text for keyword in keywords):
            return category
    return "證券"
def chat_fn(message, history):
    """Answer one user message using retrieved QA context.

    The message is routed to a category, the two most similar QA entries are
    fetched from that category's vector store, and the LLM is asked to answer
    from them. The turn is then recorded in the module-level ``memory``.

    Args:
        message: The user's question text.
        history: Prior chat turns supplied by the UI; not used here.

    Returns:
        The reply: the model's answer, a fixed notice when no LLM is
        configured (simulation mode), or an error notice if generation raises.
    """
    category = auto_detect_category(message)
    vectordb = vectordbs[category]
    docs = vectordb.similarity_search(message, k=2)
    context = "\n\n".join(d.page_content for d in docs) if docs else "查無相關資料"
    prompt = f"""
你是一位金融客服人員,請根據以下QA知識回答:
---
{context}
---
使用者問題:{message}
"""
    try:
        if llm:
            response = llm.invoke(prompt)
            reply = getattr(response, "content", None) or getattr(response, "text", "⚠️ 無回覆")
        else:
            reply = "(模擬模式)這是示範回覆,請確認已設定 GOOGLE_API_KEY。"
    except Exception as e:
        # Surface generation failures to the user instead of crashing the UI.
        reply = f"⚠️ 生成錯誤:{e}"
    # The memory is configured without input_key/output_key, so it needs exactly
    # one key on each side; role/content-style two-key dicts are rejected with
    # ValueError.
    memory.save_context({"input": message}, {"output": reply})
    return reply
# =============================================
# 6️⃣ Gradio 介面
# =============================================
# Inline the logo as base64 text when the image file exists; otherwise keep
# an empty string.
if os.path.exists(LOGO_PATH):
    with open(LOGO_PATH, "rb") as f:
        logo_base64 = base64.b64encode(f.read()).decode("utf-8")
else:
    logo_base64 = ""
# Gradio UI: header, a chat column, and a column of example-question buttons.
# NOTE(review): the source lost its indentation; the nesting below is
# reconstructed from the statement order — confirm against the deployed layout.
with gr.Blocks(
    theme="soft",
    css="""
#logo-top {
position: fixed; top: 12px; left: 18px;
background-color: white; border-radius: 10px;
padding: 6px 8px; box-shadow: 0 0 8px rgba(0,0,0,0.15);
pointer-events: none;
}
#logo-top img { width: 120px; height: auto; display: block; }
#footer { text-align:center; font-size:13px; color:#aaa; margin-top: 20px; }
"""
) as demo:
    # The logo is embedded as a data URI only when an image was loaded.
    if logo_base64:
        gr.HTML(f"<div id='logo-top'><img src='data:image/png;base64,{logo_base64}'></div>")
    gr.Markdown("## 👨‍💼 我是小智 · 您的金融好幫手 🫰")
    gr.Markdown("Powered by Gemini & LangChain")
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="💬 對話紀錄", type="messages", height=500)
            user_input = gr.Textbox(placeholder="請輸入問題...", show_label=False)
            send_btn = gr.Button("送出", variant="primary")
            def handle_input(message, history):
                """Handle one submitted message.

                Blank messages leave the history unchanged. Otherwise the
                user turn and the reply from chat_fn are appended to
                ``history`` in place. Returns the history plus an update
                that empties the input textbox.
                """
                if not message.strip():
                    return history, gr.update(value="")
                reply = chat_fn(message, history)
                history += [{"role": "user", "content": message},
                            {"role": "assistant", "content": reply}]
                return history, gr.update(value="")
            # Enter key and the send button trigger the same handler.
            user_input.submit(handle_input, [user_input, chatbot], [chatbot, user_input])
            send_btn.click(handle_input, [user_input, chatbot], [chatbot, user_input])
            # Clear button resets the chat display and the textbox only.
            gr.Button("🧹 清除對話").click(lambda: ([], gr.update(value="")), outputs=[chatbot, user_input])
        with gr.Column(scale=1):
            gr.Markdown("### 🔍 常見問題")
            examples = [
                "未成年可以開戶嗎?",
                "法人開戶要準備什麼?",
                "期貨交易保證金是什麼?",
                "複委託要如何下單?",
                "美股交易時間?",
                "美股可以定期定額嗎?"
            ]
            for q in examples:
                # q=q binds the current example text to this button's callback.
                gr.Button(q).click(lambda h, q=q: handle_input(q, h), [chatbot], [chatbot, user_input])
    gr.HTML("<div id='footer'>© Fintech Assistant — 僅業務使用,非官方授權</div>")
demo.launch()