adamtobegreat commited on
Commit
9c88b52
·
verified ·
1 Parent(s): 92d18bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -42
app.py CHANGED
@@ -6,16 +6,21 @@ from langchain.embeddings.base import Embeddings
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  import chromadb
8
  import gradio as gr
9
- from langchain.chains import LLMChain
10
- from langchain.prompts import ChatPromptTemplate
 
 
 
 
11
 
12
  try:
13
- # 新版 LangChain
14
  from langchain.memory import ConversationBufferMemory
15
  except ImportError:
16
- # 舊版 LangChain(部分 Hugging Face 環境)
17
  from langchain_community.memory import ConversationBufferMemory
18
 
 
 
 
19
  # =============================================
20
  # 1️⃣ 自訂 LM Studio Embedding 類別
21
  # =============================================
@@ -32,15 +37,16 @@ class LmStudioEmbeddings(Embeddings):
32
  res = self.client.embeddings.create(input=texts, model=self.model_name)
33
  return [x.embedding for x in res.data]
34
 
 
35
  # =============================================
36
- # 2️⃣ 載入 QA 檔案並分類
37
  # =============================================
38
- import os
39
-
40
- # 自動取得目前執行檔案所在目錄
41
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
42
  path = os.path.join(BASE_DIR, "QA_v2.txt")
43
 
 
 
 
44
  with open(path, "r", encoding="utf-8") as f:
45
  text = f.read()
46
 
@@ -60,6 +66,7 @@ print("✅ 已成功讀取 QA 並完成分類:")
60
  for k, v in qa_docs.items():
61
  print(f" {k}:{len(v)} 筆")
62
 
 
63
  # =============================================
64
  # 3️⃣ 建立三個獨立向量資料庫
65
  # =============================================
@@ -84,15 +91,18 @@ for cat, docs in qa_docs.items():
84
 
85
  print("✅ 各類別向量資料庫建立完成")
86
 
 
87
  # =============================================
88
- # 4️⃣ 初始化 Gemini LLM + 記憶模組
89
  # =============================================
90
- API_KEY = "AIzaSyAxoIHYjStZ5xPe2EoNrOapHhvVmx9QzWs"
 
 
 
91
  llm = ChatGoogleGenerativeAI(model='gemini-2.5-flash', google_api_key=API_KEY)
92
 
93
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
94
 
95
- # ✅ 只保留一個變數 input,context 會手動插入文字中
96
  prompt = ChatPromptTemplate.from_messages([
97
  ("system", "你是一位金融客服人員,請根據下列公司規章內容回答使用者問題。若內容不足,也請根據既有資訊給出合理說明,並建議洽營業員了解詳情。"),
98
  ("human", "{input}")
@@ -104,6 +114,7 @@ chain = LLMChain(
104
  memory=memory
105
  )
106
 
 
107
  # =============================================
108
  # 5️⃣ 自動分類 + 對話主邏輯
109
  # =============================================
@@ -117,6 +128,7 @@ def auto_detect_category(text):
117
  else:
118
  return "證券"
119
 
 
120
  def chat_fn(message, history):
121
  print(f"[DEBUG] 問題:{message}")
122
 
@@ -132,7 +144,6 @@ def chat_fn(message, history):
132
  docs = vectordb.similarity_search(message, k=2)
133
  context = "\n\n".join([d.page_content for d in docs]) if docs else "目前查無相關內容。"
134
 
135
- # ✅ 將 context 手動整合進輸入文字中(新版 LangChain 安全寫法)
136
  full_input = f"公司規章內容如下:\n{context}\n\n使用者問題:{message}"
137
 
138
  try:
@@ -143,34 +154,20 @@ def chat_fn(message, history):
143
 
144
  return reply or "請洽營業員"
145
 
 
146
  # =============================================
147
  # 6️⃣ Gradio 介面 + 左上角 logo
148
  # =============================================
149
-
150
- """
151
- #要在HF上部署的話需要改ㄧ下api,把它藏起來
152
-
153
- import os
154
- from langchain_google_genai import ChatGoogleGenerativeAI
155
-
156
- API_KEY = os.getenv("GOOGLE_API_KEY")
157
- if not API_KEY:
158
- raise ValueError("⚠️ 未設定 GOOGLE_API_KEY,請在 Hugging Face Secrets 中新增。")
159
-
160
- llm = ChatGoogleGenerativeAI(model='gemini-2.5-flash', google_api_key=API_KEY)
161
- """
162
- # =============================================
163
-
164
-
165
-
166
- logo_path = os.path.join(BASE_DIR, "mega.png")
167
- with open(logo_path, "rb") as f:
168
- logo_base64 = base64.b64encode(f.read()).decode("utf-8")
169
 
170
  with gr.Blocks(
171
- theme="Taithrah/Minimal",
172
  css="""
173
- /* 固定 logo 在左上角 */
174
  #logo-top {
175
  position: fixed;
176
  top: 12px;
@@ -189,14 +186,14 @@ with gr.Blocks(
189
  """
190
  ) as demo:
191
 
192
- # 插入 logo
193
- gr.HTML(f"""
194
- <div id="logo-top">
195
- <img src="data:image/png;base64,{logo_base64}" alt="logo">
196
- </div>
197
- """)
198
 
199
- gr.Markdown("<h1 style='text-align:center'>👨‍💼 我是小智,您的金融好幫手🫰</h1>")
200
 
201
  with gr.Row():
202
  with gr.Column(scale=4):
@@ -226,7 +223,6 @@ with gr.Blocks(
226
  for label, q in btns:
227
  gr.Button(label).click(lambda h, q=q: handle_input(q, h), [chatbox], [chatbox, user_input])
228
 
229
- # ✅ 清除記憶按鈕
230
  def clear_memory():
231
  memory.clear()
232
  return [], gr.update(value="", placeholder="請輸入問題...")
 
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  import chromadb
8
  import gradio as gr
9
+
10
+ # === 🔧 LangChain 版本相容導入 ===
11
+ try:
12
+ from langchain.chains import LLMChain
13
+ except ImportError:
14
+ from langchain_community.chains import LLMChain
15
 
16
  try:
 
17
  from langchain.memory import ConversationBufferMemory
18
  except ImportError:
 
19
  from langchain_community.memory import ConversationBufferMemory
20
 
21
+ from langchain.prompts import ChatPromptTemplate
22
+
23
+
24
  # =============================================
25
  # 1️⃣ 自訂 LM Studio Embedding 類別
26
  # =============================================
 
37
  res = self.client.embeddings.create(input=texts, model=self.model_name)
38
  return [x.embedding for x in res.data]
39
 
40
+
41
  # =============================================
42
+ # 2️⃣ 載入 QA 檔案並分類(相對路徑)
43
  # =============================================
 
 
 
44
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
45
  path = os.path.join(BASE_DIR, "QA_v2.txt")
46
 
47
+ if not os.path.exists(path):
48
+ raise FileNotFoundError(f"❌ 找不到 QA 檔案:{path}")
49
+
50
  with open(path, "r", encoding="utf-8") as f:
51
  text = f.read()
52
 
 
66
  for k, v in qa_docs.items():
67
  print(f" {k}:{len(v)} 筆")
68
 
69
+
70
  # =============================================
71
  # 3️⃣ 建立三個獨立向量資料庫
72
  # =============================================
 
91
 
92
  print("✅ 各類別向量資料庫建立完成")
93
 
94
+
95
  # =============================================
96
+ # 4️⃣ 初始化 Gemini LLM(讀取 Hugging Face Secret)
97
  # =============================================
98
+ API_KEY = os.getenv("GOOGLE_API_KEY")
99
+ if not API_KEY:
100
+ raise ValueError("⚠️ 未設定 GOOGLE_API_KEY,請在 Hugging Face Secrets 中新增。")
101
+
102
  llm = ChatGoogleGenerativeAI(model='gemini-2.5-flash', google_api_key=API_KEY)
103
 
104
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
105
 
 
106
  prompt = ChatPromptTemplate.from_messages([
107
  ("system", "你是一位金融客服人員,請根據下列公司規章內容回答使用者問題。若內容不足,也請根據既有資訊給出合理說明,並建議洽營業員了解詳情。"),
108
  ("human", "{input}")
 
114
  memory=memory
115
  )
116
 
117
+
118
  # =============================================
119
  # 5️⃣ 自動分類 + 對話主邏輯
120
  # =============================================
 
128
  else:
129
  return "證券"
130
 
131
+
132
  def chat_fn(message, history):
133
  print(f"[DEBUG] 問題:{message}")
134
 
 
144
  docs = vectordb.similarity_search(message, k=2)
145
  context = "\n\n".join([d.page_content for d in docs]) if docs else "目前查無相關內容。"
146
 
 
147
  full_input = f"公司規章內容如下:\n{context}\n\n使用者問題:{message}"
148
 
149
  try:
 
154
 
155
  return reply or "請洽營業員"
156
 
157
+
158
  # =============================================
159
  # 6️⃣ Gradio 介面 + 左上角 logo
160
  # =============================================
161
+ logo_path = os.path.join(BASE_DIR, "mega.png")
162
+ if os.path.exists(logo_path):
163
+ with open(logo_path, "rb") as f:
164
+ logo_base64 = base64.b64encode(f.read()).decode("utf-8")
165
+ else:
166
+ logo_base64 = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  with gr.Blocks(
169
+ theme="soft",
170
  css="""
 
171
  #logo-top {
172
  position: fixed;
173
  top: 12px;
 
186
  """
187
  ) as demo:
188
 
189
+ if logo_base64:
190
+ gr.HTML(f"""
191
+ <div id="logo-top">
192
+ <img src="data:image/png;base64,{logo_base64}" alt="logo">
193
+ </div>
194
+ """)
195
 
196
+ gr.Markdown("<h1 style='text-align:center'>👨‍💼 我是小智,您的金融好幫手 🫰</h1>")
197
 
198
  with gr.Row():
199
  with gr.Column(scale=4):
 
223
  for label, q in btns:
224
  gr.Button(label).click(lambda h, q=q: handle_input(q, h), [chatbox], [chatbox, user_input])
225
 
 
226
  def clear_memory():
227
  memory.clear()
228
  return [], gr.update(value="", placeholder="請輸入問題...")