cwadayi commited on
Commit
84b321a
·
verified ·
1 Parent(s): f1439e9

Update ai_service.py

Browse files
Files changed (1) hide show
  1. ai_service.py +103 -46
ai_service.py CHANGED
@@ -1,57 +1,114 @@
1
- # ai_service.py
2
- import torch
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration
4
- from config import LLM_MODEL, LLM_MAX_NEW_TOKENS, LLM_TEMPERATURE, LLM_TOP_K
 
5
 
6
# Lazy-init cache for the LLM: load/ok status, last error, model handles, device.
_LLM = dict(loaded=False, ok=False, err=None, model=None, tokenizer=None, device="cpu")
 
7
 
8
def _ensure_llm():
    """Load the Flan-T5 model and tokenizer on first use.

    Returns an ``(ok, err)`` tuple. The outcome is cached in ``_LLM`` so the
    heavyweight load is attempted at most once per process.
    """
    if not _LLM["loaded"]:
        _LLM["loaded"] = True
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Load the T5-specific tokenizer and model.
            tokenizer = T5Tokenizer.from_pretrained(LLM_MODEL)
            model = T5ForConditionalGeneration.from_pretrained(LLM_MODEL).to(device)
            _LLM.update({"ok": True, "model": model, "tokenizer": tokenizer, "device": device})
            print(f"Flan-T5 model '{LLM_MODEL}' loaded successfully on {device}.")
        except Exception as exc:
            _LLM["ok"] = False
            _LLM["err"] = f"{exc}"
    return _LLM["ok"], _LLM["err"]
28
 
29
def generate_ai_text(user_prompt: str) -> str:
    """Generate a text reply using the lazily loaded Flan-T5 model.

    Returns the decoded model output, or a human-readable error message when
    the model is unavailable or generation fails.
    """
    ok, err = _ensure_llm()
    if not ok:
        return f"🤖 AI 模型無法使用。\n詳細錯誤:{err}"

    tokenizer = _LLM["tokenizer"]
    model = _LLM["model"]
    device = _LLM["device"]

    # Wrap the user's question in a generic instruction prompt for Flan-T5.
    input_text = f"請用繁體中文回答以下問題: {user_prompt}"

    try:
        encoded = tokenizer(input_text, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)

        with torch.no_grad():
            generated = model.generate(
                input_ids,
                max_new_tokens=LLM_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=LLM_TEMPERATURE,
                top_k=LLM_TOP_K,
            )

        text = tokenizer.decode(generated[0], skip_special_tokens=True).strip()
        return text or "(AI 沒有產生任何內容)"
    except Exception as e:
        return f"AI 產生內容時發生錯誤:{e}"
 
 
1
+ # ai_service.py (Gemini 最終版)
2
+ import json
3
+ from datetime import datetime
4
+ import google.generativeai as genai
5
+ from gradio_client import Client
6
 
7
+ # 從設定檔匯入金鑰和 URL
8
+ from config import GEMINI_API_KEY, MCP_SERVER_URL
9
 
10
# --- 1. Configure the Gemini API key (one-time, at import) ---
# Only configure when a key is present and is not the placeholder value.
_key_configured = bool(GEMINI_API_KEY) and "YOUR_GEMINI_API_KEY" not in GEMINI_API_KEY
if _key_configured:
    genai.configure(api_key=GEMINI_API_KEY)
 
 
13
 
14
# --- 2. Tool function (earthquake search) ---
def call_mcp_earthquake_search(
    start_date: str,
    end_date: str,
    min_magnitude: float = 4.5,
    max_magnitude: float = 8.0
) -> str:
    """Search earthquake events on the remote MCP server.

    Queries by date range and magnitude window (location defaults to the
    Taiwan region). Returns a JSON string of matching records, or a
    human-readable message when nothing matches or the call fails.
    """
    try:
        print("--- 正在呼叫遠端地震 MCP 伺服器 (由 Gemini 觸發) ---")
        print(f" 查詢條件: {start_date} 到 {end_date}, 規模 {min_magnitude} 以上")

        mcp_client = Client(src=MCP_SERVER_URL)
        prediction = mcp_client.predict(
            param_0=start_date, param_1="00:00:00",
            param_2=end_date, param_3="23:59:59",
            param_4=21.0, param_5=26.0,    # default latitude window (Taiwan)
            param_6=119.0, param_7=123.0,  # default longitude window (Taiwan)
            param_8=0.0, param_9=100.0,
            param_10=min_magnitude, param_11=max_magnitude,
            api_name="/gradio_fetch_and_plot_data"
        )
        # First element of the prediction is a dataframe-like dict
        # with 'headers' and 'data' keys.
        table = prediction[0]
        rows = table.get('data', [])

        if not rows:
            print("--- MCP 伺服器回傳:未找到符合條件的地震 ---")
            return "查詢完成,但未找到任何符合條件的地震資料。"

        cols = table.get('headers', [])
        records = [dict(zip(cols, row)) for row in rows]
        print(f"--- MCP 伺服器成功回傳 {len(rows)} 筆資料 ---")
        return json.dumps(records, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"呼叫 MCP 伺服器失敗: {e}")
        return f"工具執行失敗,錯誤訊息: {e}"
 
50
 
51
# --- 3. Tool declaration for Gemini ---
# NOTE(review): the end_date default shown in the description is computed once
# at import time, so a long-running process keeps advertising a stale "today".
earthquake_search_tool_declaration = {
    "name": "call_earthquake_search_tool",
    "description": "根據指定的條件(時間、地點、規模等)從台灣中央氣象署的資料庫中搜尋地震事件。預設搜尋台灣周邊地區。",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "start_date": {"type": "STRING", "description": "搜尋的開始日期,格式為 'YYYY-MM-DD'。"},
            "end_date": {"type": "STRING", "description": f"搜尋的結束日期,格式為 'YYYY-MM-DD'。預設為今天: {datetime.now().strftime('%Y-%m-%d')}。"},
            "min_magnitude": {"type": "NUMBER", "description": "要搜尋的最小地震規模。預設為 4.5。"},
            "max_magnitude": {"type": "NUMBER", "description": "要搜尋的最大地震規模。預設為 8.0。"},
        },
        # Only the date range is truly required: the magnitude bounds have
        # defaults in the tool function (4.5 / 8.0). The original declaration
        # marked min_magnitude required, contradicting its own "預設為 4.5"
        # description and forcing the model to always supply it.
        "required": ["start_date", "end_date"]
    }
}
64
 
65
# Dispatch table: tool name advertised to Gemini -> local implementation.
available_tools = {
    "call_earthquake_search_tool": call_mcp_earthquake_search,
}
 
 
66
 
67
# --- 4. Build the Gemini model (module-level singleton) ---
def _build_gemini_model():
    """Return a tool-enabled GenerativeModel, or None when the key is unset/invalid
    or construction fails."""
    if not GEMINI_API_KEY or "YOUR_GEMINI_API_KEY" in GEMINI_API_KEY:
        return None
    try:
        return genai.GenerativeModel(
            model_name="gemini-1.5-flash",
            tools=[earthquake_search_tool_declaration],
        )
    except Exception as e:
        print(f"建立 Gemini 模型失敗: {e}")
        return None


model = _build_gemini_model()
77
 
78
# --- 5. Main AI text-generation entry point ---
def generate_ai_text(user_prompt: str) -> str:
    """Generate a reply with Gemini, dispatching a tool call when requested.

    Flow: send the prompt; if the model answers with a function call, run the
    matching local tool, feed the result back, and return the model's final
    text. Returns a human-readable error string on any failure.
    """
    if not model:
        return "🤖 AI (Gemini) 服務尚未設定 API 金鑰,或金鑰無效。"
    try:
        print(f"--- 開始 Gemini 對話,使用者輸入: '{user_prompt}' ---")
        chat = model.start_chat()
        response = chat.send_message(user_prompt)

        # Did the model ask us to run a tool?
        try:
            function_call = response.candidates[0].content.parts[0].function_call
        except (IndexError, AttributeError):
            function_call = None

        if not function_call:
            print("--- Gemini 直接回覆文字 ---")
            return response.text

        # Dispatch the requested tool.
        print(f"--- Gemini 要求呼叫工具: {function_call.name} ---")
        tool_function = available_tools.get(function_call.name)
        if not tool_function:
            return f"錯誤:模型嘗試呼叫一個不存在的工具 '{function_call.name}'。"

        tool_result = tool_function(**dict(function_call.args))
        print("--- 將工具結果回傳給 Gemini ---")
        # BUGFIX: the google.generativeai package has no top-level `genai.Part`
        # (parts live under `genai.protos`), so the original
        # `genai.Part(function_response=...)` raised AttributeError and the
        # tool path always fell through to the generic error return. Wrap the
        # tool result in protos.Part/FunctionResponse as the SDK expects.
        response = chat.send_message(
            genai.protos.Part(
                function_response=genai.protos.FunctionResponse(
                    name=function_call.name,
                    response={"result": tool_result},
                )
            )
        )
        print("--- Gemini 根據工具結果生成最終回覆 ---")
        return response.text

    except Exception as e:
        print(f"與 Gemini AI 互動時發生錯誤: {e}")
        return f"🤖 AI 服務發生錯誤: {e}"