adamtobegreat commited on
Commit
bc2dd3b
·
verified ·
1 Parent(s): 6d871f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -23
app.py CHANGED
@@ -99,6 +99,142 @@ def chat_fn(message, history):
99
  # =============================================
100
  # 6️⃣ Gradio 介面
101
  # =============================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  logo_path = os.path.join(BASE_DIR, "mega.png")
103
  logo_base64 = ""
104
  if os.path.exists(logo_path):
@@ -155,28 +291,22 @@ with gr.Blocks(
155
  box-shadow: inset 0 0 1px rgba(0,0,0,0.05);
156
  }
157
 
158
- /* 🟢 小巧箭頭按鈕(含飛出動畫) */
159
  #send-btn {
160
- background-color: #00b800; border: none; border-radius: 50%;
161
- width: 28px; height: 28px; cursor: pointer;
162
- display: flex; align-items: center; justify-content: center;
 
 
 
 
 
 
163
  transition: background-color 0.2s ease, transform 0.1s ease;
164
- box-shadow: 0 1px 2px rgba(0,0,0,0.1); padding: 0; overflow: hidden;
165
- }
166
- #send-btn svg {
167
- width: 12px; height: 12px; fill: white;
168
- transition: transform 0.25s ease;
169
  }
170
  #send-btn:hover { background-color: #00a000; }
171
- #send-btn:hover svg { transform: rotate(10deg) scale(1.15); }
172
- #send-btn:active { transform: scale(0.9); }
173
- #send-btn:active svg { animation: send-fly 0.5s ease-out; }
174
-
175
- @keyframes send-fly {
176
- 0% { transform: translateX(0) scale(1); opacity: 1; }
177
- 50% { transform: translateX(8px) scale(1.2); opacity: 0; }
178
- 100% { transform: translateX(0) scale(1); opacity: 1; }
179
- }
180
  """
181
  ) as demo:
182
  # 左上角 logo
@@ -200,11 +330,7 @@ with gr.Blocks(
200
 
201
  with gr.Row(elem_id="input-row"):
202
  user_input = gr.Textbox(elem_id="user-input", show_label=False, placeholder="輸入訊息...", scale=8)
203
- send_btn = gr.Button(
204
- value="""
205
- <svg viewBox="0 0 24 24"><path d="M3 12l18-9-6 9 6 9z"/></svg>
206
- """, elem_id="send-btn", scale=1
207
- )
208
 
209
  def handle_input(message, history):
210
  if not message.strip():
 
99
  # =============================================
100
  # 6️⃣ Gradio 介面
101
  # =============================================
102
+ import os, re, base64
103
+ from langchain_core.documents import Document
104
+ from langchain_chroma import Chroma
105
+ from openai import OpenAI
106
+ from langchain.embeddings.base import Embeddings
107
+ from langchain_google_genai import ChatGoogleGenerativeAI
108
+ import chromadb
109
+ import gradio as gr
110
+
111
+ # === 記憶模組相容多版本 ===
112
+ try:
113
+ from langchain_memory import ConversationBufferMemory
114
+ except ImportError:
115
+ try:
116
+ from langchain.memory import ConversationBufferMemory
117
+ except ImportError:
118
+ from langchain_community.memory import ConversationBufferMemory
119
+
120
+
121
+ # =============================================
122
+ # 1️⃣ 自訂 LM Studio Embedding 類別
123
+ # =============================================
124
+ class LmStudioEmbeddings(Embeddings):
125
+ def __init__(self, model_name, url):
126
+ self.model_name = model_name
127
+ self.client = OpenAI(base_url=url, api_key="lm-studio")
128
+
129
+ def embed_query(self, text: str):
130
+ res = self.client.embeddings.create(input=text, model=self.model_name)
131
+ return res.data[0].embedding
132
+
133
+ def embed_documents(self, texts: list[str]):
134
+ res = self.client.embeddings.create(input=texts, model=self.model_name)
135
+ return [x.embedding for x in res.data]
136
+
137
+
138
+ # =============================================
139
+ # 2️⃣ 載入 QA 檔案並分類
140
+ # =============================================
141
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
142
+ qa_path = os.path.join(BASE_DIR, "QA_v2.txt")
143
+
144
+ if not os.path.exists(qa_path):
145
+ raise FileNotFoundError(f"❌ 找不到 QA 檔案:{qa_path}")
146
+
147
+ with open(qa_path, "r", encoding="utf-8") as f:
148
+ text = f.read()
149
+
150
+ pattern = r"(Q[::].*?)(?=Q[::]|$)"
151
+ qas = re.findall(pattern, text, flags=re.S)
152
+ qa_docs = {"證券": [], "期貨": [], "複委託": []}
153
+ for qa in qas:
154
+ if "證券" in qa:
155
+ qa_docs["證券"].append(Document(page_content=qa.strip()))
156
+ elif "期貨" in qa:
157
+ qa_docs["期貨"].append(Document(page_content=qa.strip()))
158
+ elif "複委託" in qa:
159
+ qa_docs["複委託"].append(Document(page_content=qa.strip()))
160
+
161
+ print("✅ 已成功讀取 QA 並完成分類:")
162
+ for k, v in qa_docs.items():
163
+ print(f" {k}:{len(v)} 筆")
164
+
165
+
166
+ # =============================================
167
+ # 3️⃣ 建立向量資料庫
168
+ # =============================================
169
+ embedding = LmStudioEmbeddings(
170
+ model_name="text-embedding-bge-large-zh-v1.5",
171
+ url="http://127.0.0.1:1234/v1"
172
+ )
173
+ client = chromadb.PersistentClient(path="./chroma_db")
174
+
175
+ collection_names = {"證券": "stocks", "期貨": "futures", "複委託": "overseas"}
176
+ vectordbs = {}
177
+ for cat, docs in qa_docs.items():
178
+ vectordbs[cat] = Chroma(
179
+ client=client,
180
+ collection_name=collection_names[cat],
181
+ embedding_function=embedding
182
+ )
183
+ if len(vectordbs[cat].get()["documents"]) == 0:
184
+ vectordbs[cat].add_documents(docs)
185
+ print("✅ 各類別向量資料庫建立完成")
186
+
187
+
188
+ # =============================================
189
+ # 4️⃣ 初始化 Gemini LLM
190
+ # =============================================
191
+ API_KEY = os.getenv("GOOGLE_API_KEY")
192
+ if not API_KEY:
193
+ raise ValueError("⚠️ 未設定 GOOGLE_API_KEY,請在 Hugging Face Secrets 中新增。")
194
+
195
+ llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=API_KEY)
196
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
197
+
198
+
199
+ # =============================================
200
+ # 5️⃣ 對話邏輯
201
+ # =============================================
202
+ def auto_detect_category(text):
203
+ if any(k in text for k in ["股票", "證券", "開戶", "下單", "交割"]):
204
+ return "證券"
205
+ elif any(k in text for k in ["期貨", "選擇權", "保證金"]):
206
+ return "期貨"
207
+ elif any(k in text for k in ["複委託", "海外", "美股", "港股"]):
208
+ return "複委託"
209
+ else:
210
+ return "證券"
211
+
212
+
213
+ def chat_fn(message, history):
214
+ category = auto_detect_category(message)
215
+ vectordb = vectordbs.get(category)
216
+ docs = vectordb.similarity_search(message, k=2)
217
+ context = "\n\n".join([d.page_content for d in docs]) if docs else "查無資料"
218
+
219
+ prompt = f"""
220
+ 你是一位金融客服人員,根據以下公司QA回答客戶問題:
221
+ ---
222
+ {context}
223
+ ---
224
+ 使用者問題:{message}
225
+ """
226
+
227
+ try:
228
+ response = llm.invoke(prompt)
229
+ reply = response.content.strip()
230
+ except Exception as e:
231
+ reply = f"⚠️ 生成錯誤:{e}"
232
+ return reply or "請洽營業員"
233
+
234
+
235
+ # =============================================
236
+ # 6️⃣ 介面(LINE風格 + 純白footer + 小輸入按鈕)
237
+ # =============================================
238
  logo_path = os.path.join(BASE_DIR, "mega.png")
239
  logo_base64 = ""
240
  if os.path.exists(logo_path):
 
291
  box-shadow: inset 0 0 1px rgba(0,0,0,0.05);
292
  }
293
 
294
+ /* 🟢 小巧文字版「輸入」按鈕 */
295
  #send-btn {
296
+ background-color: #00b800;
297
+ color: white;
298
+ border: none;
299
+ border-radius: 14px;
300
+ height: 26px;
301
+ padding: 0 10px;
302
+ font-size: 13px;
303
+ font-weight: 600;
304
+ cursor: pointer;
305
  transition: background-color 0.2s ease, transform 0.1s ease;
306
+ box-shadow: 0 1px 2px rgba(0,0,0,0.1);
 
 
 
 
307
  }
308
  #send-btn:hover { background-color: #00a000; }
309
+ #send-btn:active { transform: scale(0.95); }
 
 
 
 
 
 
 
 
310
  """
311
  ) as demo:
312
  # 左上角 logo
 
330
 
331
  with gr.Row(elem_id="input-row"):
332
  user_input = gr.Textbox(elem_id="user-input", show_label=False, placeholder="輸入訊息...", scale=8)
333
+ send_btn = gr.Button("輸入", elem_id="send-btn", scale=1)
 
 
 
 
334
 
335
  def handle_input(message, history):
336
  if not message.strip():