Mega_QA / app.py
adamtobegreat's picture
Update app.py
a3bae88 verified
raw
history blame
12.1 kB
"""
======================================================
📘 金融客服小智(Fintech Assistant)
版本:v3.2 (穩定正式版)
更新重點:
1. 修正 LangChain 記憶格式(避免 ValueError)
2. 回復原生輸入框樣式(類似 LINE 的簡潔列)
3. 保留手機自適應、桌面置中、右欄清除鍵
======================================================
"""
import os, re, base64
import chromadb
import gradio as gr
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
# === 記憶模組相容多版本 ===
try:
from langchain_memory import ConversationBufferMemory
except ImportError:
try:
from langchain.memory import ConversationBufferMemory
except ImportError:
from langchain_community.memory import ConversationBufferMemory
# =============================================
# 1️⃣ Embedding 與基礎設定
# =============================================
# Sentence-embedding model handed to each Chroma collection created later in this file.
embedding = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")

# Data files are looked up relative to the process working directory.
BASE_DIR = os.getcwd()
QA_PATH, LOGO_PATH = (
    os.path.join(BASE_DIR, filename) for filename in ("QA_v2.txt", "mega.png")
)

# A missing key is not fatal: the app only prints a warning and keeps going.
API_KEY = os.environ.get("GOOGLE_API_KEY")
if not API_KEY:
    print("⚠️ 尚未設定 GOOGLE_API_KEY,系統將以模擬模式運行。")
# =============================================
# 2️⃣ QA 載入與分類
# =============================================
def load_qa_documents(path: str):
    """Read a Q/A text file and bucket each Q/A pair by product category.

    The file is split into "Q: ... A: ..." chunks. Each chunk becomes one
    ``Document`` and is stored under the first category name (checked in the
    order 證券, 期貨, 複委託) that occurs anywhere in the chunk. Chunks that
    mention none of them are stored under 證券.

    Args:
        path: Path of the UTF-8 encoded QA text file.

    Returns:
        A dict mapping each of the three category names to a list of
        ``Document`` objects (possibly empty).
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()

    buckets = {"證券": [], "期貨": [], "複委託": []}
    # DOTALL lets a single Q/A pair span several lines; the lookahead stops a
    # pair right before the next "Q:" marker or at the end of the text.
    qa_regex = re.compile(r"(Q[::].*?A[::].*?)(?=Q[::]|$)", re.S)

    for chunk in qa_regex.findall(raw_text):
        # First category name contained in the chunk wins; default is 證券.
        label = next((name for name in buckets if name in chunk), "證券")
        buckets[label].append(Document(page_content=chunk.strip()))
    return buckets
# Build the per-category knowledge base, or fall back to empty buckets when
# the QA file is absent so the rest of the script can still run.
if not os.path.exists(QA_PATH):
    print("⚠️ 未找到 QA_v2.txt,啟用空白知識庫模式。")
    qa_docs = {"證券": [], "期貨": [], "複委託": []}
else:
    qa_docs = load_qa_documents(QA_PATH)
    print("✅ 已載入 QA 檔案,共分為:", {k: len(v) for k, v in qa_docs.items()})
# =============================================
# 3️⃣ 向量資料庫初始化
# =============================================
# One Chroma collection per product category, all sharing a single client.
client = chromadb.Client()
collection_map = {"證券": "stocks", "期貨": "futures", "複委託": "overseas"}
vectordbs = {}
for cat, docs in qa_docs.items():
    vectordb = Chroma(
        client=client,
        collection_name=collection_map[cat],
        embedding_function=embedding,
    )
    collection = vectordb._collection
    # Only seed a collection that reports zero entries, and only when there
    # is something to insert.
    is_empty = hasattr(collection, "count") and collection.count() == 0
    if is_empty and docs:
        vectordb.add_documents(docs)
    vectordbs[cat] = vectordb
print("✅ 向量資料庫初始化完成。")
# =============================================
# 4️⃣ 初始化 LLM 與記憶體
# =============================================
# llm stays None when no API key is configured; chat_fn then returns a canned
# demo reply instead of calling the model.
llm = (
    ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=API_KEY)
    if API_KEY
    else None
)
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# =============================================
# 5️⃣ 對話邏輯
# =============================================
def auto_detect_category(text: str):
    """Pick the knowledge-base category that best matches a user question.

    The product-specific keyword lists (期貨, 複委託) are tested before the
    securities list. The securities list contains generic action words
    (開戶, 下單, 交割) that also appear in questions about the other
    products; testing it first routed e.g. "複委託要如何下單?" to the
    securities knowledge base.

    Args:
        text: The raw user question.

    Returns:
        One of "證券", "期貨" or "複委託". Questions that match no keyword
        (including the empty string) fall back to "證券".
    """
    keyword_rules = (
        ("期貨", ("期貨", "選擇權", "保證金")),
        ("複委託", ("複委託", "海外", "美股", "港股")),
        # Generic terms live here, so this list must be checked last.
        ("證券", ("股票", "證券", "開戶", "下單", "交割")),
    )
    for category, keywords in keyword_rules:
        if any(keyword in text for keyword in keywords):
            return category
    return "證券"
def chat_fn(message, history):
    """Answer one user question using the category-specific QA knowledge base.

    The question is routed to a category, the two most similar QA entries are
    retrieved from that category's vector store, and they are embedded in a
    prompt for the LLM. When no LLM is configured a fixed demo reply is
    returned instead. Any exception raised while generating is turned into
    an error message string rather than propagated.

    Args:
        message: The user's question.
        history: The chat transcript supplied by the UI.
            NOTE(review): this argument is not read anywhere in the function.

    Returns:
        The reply text to show to the user.
    """
    store = vectordbs[auto_detect_category(message)]
    hits = store.similarity_search(message, k=2)
    if hits:
        context = "\n\n".join(hit.page_content for hit in hits)
    else:
        context = "查無相關資料"

    prompt = (
        "\n"
        "你是一位金融客服人員,請根據以下QA知識回答:\n"
        "---\n"
        f"{context}\n"
        "---\n"
        f"使用者問題:{message}\n"
    )

    try:
        if not llm:
            reply = "(模擬模式)這是示範回覆,請確認是否已設定 GOOGLE_API_KEY。"
        else:
            response = llm.invoke(prompt)
            # Prefer .content; fall back to .text, then to a placeholder.
            reply = getattr(response, "content", None) or getattr(response, "text", "⚠️ 無回覆")
    except Exception as e:
        reply = f"⚠️ 生成錯誤:{e}"

    # save_context is given one dict of inputs and one dict of outputs.
    # NOTE(review): nothing in the visible file reads this memory back into
    # the prompt — confirm whether it is consumed elsewhere.
    memory.save_context({"input": message}, {"output": reply})
    return reply
# =============================================
# 6️⃣ Gradio 介面
# =============================================
# === Logo 圖片處理 ===
# Base64 text of the logo image for inlining into HTML; stays empty when the
# file is missing.
logo_base64 = ""
if os.path.exists(LOGO_PATH):
    with open(LOGO_PATH, "rb") as logo_file:
        logo_bytes = logo_file.read()
    logo_base64 = base64.b64encode(logo_bytes).decode("utf-8")
# Page-level stylesheet passed to gr.Blocks (logo badge, responsive title,
# chat input row proportions).
_APP_CSS = """
#logo-top {
position: fixed; top: 12px; left: 18px;
background-color: white; border-radius: 10px;
padding: 6px 8px; box-shadow: 0 0 8px rgba(0,0,0,0.15);
pointer-events: none;
}
#logo-top img { width: 120px; height: auto; display: block; }
#footer { text-align:center; font-size:13px; color:#aaa; margin-top: 20px; }
/* 手機寬度下讓 Row 自動垂直排列 */
@media (max-width: 768px) {
.gr-block.gr-row {
flex-direction: column !important;
}
#logo-top img { width: 90px; }
.gradio-container { padding: 8px; }
#footer { font-size: 12px; margin-top: 10px; }
}
/* === 桌機/手機自適應標題 === */
#main-title {
text-align: center;
font-weight: bold;
font-size: 26px;
margin-top: 60px;
margin-bottom: 6px;
}
.title-line {
display: flex;
justify-content: center;
align-items: center;
gap: 8px;
flex-wrap: nowrap;
}
.subtitle {
white-space: nowrap;
}
@media (max-width: 768px) {
.title-line {
flex-direction: column;
gap: 4px;
}
#main-title {
font-size: 22px;
line-height: 1.4;
}
}
/* ✅ 修正輸入框高度與按鈕比例 */
#chat-input textarea {
height: 48px !important;
min-height: 48px !important;
font-size: 16px !important;
padding: 8px 12px !important;
border-radius: 10px !important;
}
#chat-row {
align-items: center !important;
gap: 4px !important;
}
#send-btn {
height: 48px !important;
font-size: 16px !important;
border-radius: 10px !important;
}
/* ✅ 桌機版比例:輸入框 8、按鈕 2 */
#chat-row .gr-textbox, #chat-row textarea { flex: 9 !important; width: 90% !important; }
#chat-row .gr-button, #chat-row button { flex: 1 !important; width: 10% !important; }
/* ✅ 手機版比例:輸入框 9、按鈕 1(強制套用到 Hugging Face 結構) */
@media (max-width: 768px) {
#chat-row .gr-textbox, #chat-row textarea { flex: 9 !important; width: 90% !important; }
#chat-row .gr-button, #chat-row button { flex: 1 !important; width: 10% !important; max-width: 80px !important; min-width: 60px !important; }
#send-btn button { padding: 0 10px !important; }
}
"""

# Title markup: one line on desktop, stacked on narrow screens (see the CSS).
_TITLE_HTML = """
<div id="main-title">
<span class="title-line">
👨‍💼 我是小智
<span class="subtitle">您的金融好幫手 🫰</span>
</span>
</div>
"""

# Script that binds Enter (without Shift) on every textarea to a click on the
# send button.
# NOTE(review): confirm that Gradio executes <script> tags placed inside
# gr.HTML. If it does, Enter fires both the textbox's own submit event and
# this synthetic click, which would send the message twice.
_ENTER_KEY_SCRIPT = """
<script>
document.addEventListener("DOMContentLoaded", function() {
const observer = new MutationObserver(() => {
const textareas = document.querySelectorAll("textarea");
textareas.forEach((ta) => {
if (!ta.dataset.bound) {
ta.dataset.bound = "true";
ta.addEventListener("keydown", function(e) {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
const sendBtn = document.querySelector('#send-btn button, #send-btn');
if (sendBtn) sendBtn.click();
}
});
}
});
});
observer.observe(document.body, { childList: true, subtree: true });
});
</script>
"""

with gr.Blocks(theme="soft", css=_APP_CSS) as demo:
    if logo_base64:
        gr.HTML(f"<div id='logo-top'><img src='data:image/png;base64,{logo_base64}'></div>")

    # Title (same line on desktop, wraps on mobile).
    gr.HTML(_TITLE_HTML)
    gr.Markdown("<div style='text-align:center; color:gray;'>Powered by Gemini & LangChain</div>")

    with gr.Row(equal_height=False):
        # Left column: chat area.
        with gr.Column(scale=4, min_width=300):
            chatbot = gr.Chatbot(label="💬 對話紀錄", type="messages", height=500)

            # Input box and send button share one row (9:1).
            with gr.Row(elem_id="chat-row"):
                user_input = gr.Textbox(
                    placeholder="請輸入您的問題(Enter 送出 / Shift+Enter 換行)...",
                    show_label=False,
                    lines=1,
                    max_lines=3,
                    elem_id="chat-input",
                    scale=9,
                )
                send_btn = gr.Button(
                    "送出",
                    variant="primary",
                    elem_id="send-btn",
                    scale=1,
                )

            def handle_input(message, history):
                """Append one user/assistant exchange to the chat transcript.

                Args:
                    message: Text typed by the user (or an example question).
                    history: Current transcript as a list of
                        ``{"role", "content"}`` dicts, or ``None``.

                Returns:
                    A ``(transcript, textbox_update)`` pair: the new transcript
                    for the chatbot and an update that clears the input box.
                    Blank or missing messages leave the transcript unchanged.
                """
                # Work on a copy so the list owned by the caller is never
                # mutated in place.
                history = list(history) if history else []
                if not message or not message.strip():
                    return history, gr.update(value="")
                reply = chat_fn(message, history)
                history.extend(
                    [
                        {"role": "user", "content": message},
                        {"role": "assistant", "content": reply},
                    ]
                )
                return history, gr.update(value="")

            # Enter in the textbox and a click on the button do the same thing.
            user_input.submit(handle_input, [user_input, chatbot], [chatbot, user_input])
            send_btn.click(handle_input, [user_input, chatbot], [chatbot, user_input])

            gr.HTML(_ENTER_KEY_SCRIPT)

        # Right column: frequently asked questions + clear button.
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### 🔍 常見問題")
            examples = [
                "未成年可以開戶嗎?",
                "法人開戶要準備什麼?",
                "期貨交易保證金是什麼?",
                "複委託要如何下單?",
                "美股交易時間?",
                "美股可以定期定額嗎?",
            ]
            for q in examples:
                # The live transcript is passed in as the event input. The
                # previous `history=[]` default was a single list shared by
                # every click of a button, so it kept growing across clicks
                # and replaced the visible conversation instead of extending
                # it. `q=q` freezes the loop variable for each button.
                gr.Button(q).click(
                    fn=lambda history, q=q: handle_input(q, history),
                    inputs=[chatbot],
                    outputs=[chatbot, user_input],
                )

            def clear_all():
                """Reset the shared memory and empty both the chat and the input box."""
                memory.clear()
                return [], gr.update(value="")

            gr.Markdown("---")
            gr.Button("🧹 清除對話").click(clear_all, outputs=[chatbot, user_input])

    gr.HTML("<div id='footer'>© Fintech Assistant — 僅業務使用,非官方授權</div>")

demo.launch()