adamtobegreat committed on
Commit
2c81513
·
verified ·
1 Parent(s): 8987e9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -61
app.py CHANGED
@@ -1,41 +1,22 @@
1
- import os, re, requests, base64
2
  from langchain_core.documents import Document
3
  from langchain_chroma import Chroma
4
  from openai import OpenAI
5
  from langchain.embeddings.base import Embeddings
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
- import chromadb
8
  import gradio as gr
9
-
10
- # === 記憶模組相容多版本 ===
11
- try:
12
- from langchain_memory import ConversationBufferMemory
13
- except ImportError:
14
- try:
15
- from langchain.memory import ConversationBufferMemory
16
- except ImportError:
17
- from langchain_community.memory import ConversationBufferMemory
18
-
19
 
20
  # =============================================
21
- # 1️⃣ 自訂 LM Studio Embedding 類別
22
  # =============================================
23
- class LmStudioEmbeddings(Embeddings):
24
- def __init__(self, model_name, url):
25
- self.model_name = model_name
26
- self.client = OpenAI(base_url=url, api_key="lm-studio")
27
-
28
- def embed_query(self, text: str):
29
- res = self.client.embeddings.create(input=text, model=self.model_name)
30
- return res.data[0].embedding
31
-
32
- def embed_documents(self, texts: list[str]):
33
- res = self.client.embeddings.create(input=texts, model=self.model_name)
34
- return [x.embedding for x in res.data]
35
 
 
36
 
37
  # =============================================
38
- # 2️⃣ 載入 QA 檔案並分類(相對路徑)
39
  # =============================================
40
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
41
  path = os.path.join(BASE_DIR, "QA_v2.txt")
@@ -58,38 +39,17 @@ for qa in qas:
58
  elif "複委託" in qa:
59
  qa_docs["複委託"].append(Document(page_content=qa.strip(), metadata={"source": path}))
60
 
61
- print("✅ 已成功讀取 QA 並完成分類:")
62
- for k, v in qa_docs.items():
63
- print(f" {k}:{len(v)} 筆")
64
-
65
 
66
  # =============================================
67
- # 3️⃣ 建立向量資料庫
68
  # =============================================
69
- embedding = LmStudioEmbeddings(
70
- model_name="text-embedding-bge-large-zh-v1.5",
71
- url="http://127.0.0.1:1234/v1"
72
- )
73
-
74
- client = chromadb.PersistentClient(path="./chroma_db")
75
- collection_names = {"證券": "stocks", "期貨": "futures", "複委託": "overseas"}
76
-
77
  vectordbs = {}
78
- for cat, docs in qa_docs.items():
79
- eng_name = collection_names[cat]
80
- vectordbs[cat] = Chroma(
81
- client=client,
82
- collection_name=eng_name,
83
- embedding_function=embedding
84
- )
85
- if len(vectordbs[cat].get()["documents"]) == 0:
86
- vectordbs[cat].add_documents(docs)
87
-
88
- print("✅ 各類別向量資料庫建立完成")
89
-
90
 
91
  # =============================================
92
- # 4️⃣ 初始化 Gemini LLM(從 Secret 讀取)
93
  # =============================================
94
  API_KEY = os.getenv("GOOGLE_API_KEY")
95
  if not API_KEY:
@@ -98,9 +58,8 @@ if not API_KEY:
98
  llm = ChatGoogleGenerativeAI(model='gemini-2.5-flash', google_api_key=API_KEY)
99
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
100
 
101
-
102
  # =============================================
103
- # 5️⃣ 對話主邏輯
104
  # =============================================
105
  def auto_detect_category(text):
106
  if any(k in text for k in ["股票", "證券", "開戶", "下單", "交割", "現股"]):
@@ -112,13 +71,7 @@ def auto_detect_category(text):
112
  else:
113
  return "證券"
114
 
115
-
116
  def chat_fn(message, history):
117
- print(f"[DEBUG] 問題:{message}")
118
-
119
- if "午餐吃什麼" in message:
120
- return "還在盤中交易無法離開,還是我們約下午茶如何?"
121
-
122
  category = auto_detect_category(message)
123
  vectordb = vectordbs.get(category)
124
  if not vectordb:
@@ -143,7 +96,6 @@ def chat_fn(message, history):
143
 
144
  return reply or "請洽營業員"
145
 
146
-
147
  # =============================================
148
  # 6️⃣ Gradio 介面
149
  # =============================================
 
1
+ import os, re, base64
2
  from langchain_core.documents import Document
3
  from langchain_chroma import Chroma
4
  from openai import OpenAI
5
  from langchain.embeddings.base import Embeddings
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_community.vectorstores import FAISS
8
  import gradio as gr
9
+ from langchain.memory import ConversationBufferMemory
 
 
 
 
 
 
 
 
 
10
 
11
  # =============================================
12
+ # 1️⃣ 內建 Embedding:使用 Gemini embedding API
13
  # =============================================
14
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ embedding = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=os.getenv("GOOGLE_API_KEY"))
17
 
18
  # =============================================
19
+ # 2️⃣ 載入 QA 檔案並分類
20
  # =============================================
21
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
22
  path = os.path.join(BASE_DIR, "QA_v2.txt")
 
39
  elif "複委託" in qa:
40
  qa_docs["複委託"].append(Document(page_content=qa.strip(), metadata={"source": path}))
41
 
42
+ print("✅ 已成功讀取 QA 並完成分類:", {k: len(v) for k, v in qa_docs.items()})
 
 
 
43
 
44
  # =============================================
45
+ # 3️⃣ 建立向量資料庫(使用 FAISS,記憶體型)
46
  # =============================================
 
 
 
 
 
 
 
 
47
  vectordbs = {}
48
+ for k, docs in qa_docs.items():
49
+ vectordbs[k] = FAISS.from_documents(docs, embedding)
 
 
 
 
 
 
 
 
 
 
50
 
51
  # =============================================
52
+ # 4️⃣ 初始化 Gemini LLM
53
  # =============================================
54
  API_KEY = os.getenv("GOOGLE_API_KEY")
55
  if not API_KEY:
 
58
  llm = ChatGoogleGenerativeAI(model='gemini-2.5-flash', google_api_key=API_KEY)
59
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
60
 
 
61
  # =============================================
62
+ # 5️⃣ 對話邏輯
63
  # =============================================
64
  def auto_detect_category(text):
65
  if any(k in text for k in ["股票", "證券", "開戶", "下單", "交割", "現股"]):
 
71
  else:
72
  return "證券"
73
 
 
74
  def chat_fn(message, history):
 
 
 
 
 
75
  category = auto_detect_category(message)
76
  vectordb = vectordbs.get(category)
77
  if not vectordb:
 
96
 
97
  return reply or "請洽營業員"
98
 
 
99
  # =============================================
100
  # 6️⃣ Gradio 介面
101
  # =============================================