Song committed on
Commit
f2f2687
·
1 Parent(s): b3c381d

long return

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -2,6 +2,7 @@
2
  # -*- coding: utf-8 -*-
3
  # ---------- 環境與快取設定 (應置於最前) ----------
4
  import os
 
5
  from typing import List, Dict
6
  from contextlib import asynccontextmanager
7
  from fastapi import FastAPI, Request, HTTPException
@@ -18,7 +19,7 @@ from linebot.v3.messaging import (
18
  from linebot.v3.webhook import WebhookParser
19
  from linebot.v3.exceptions import InvalidSignatureError
20
 
21
- from openai import OpenAI
22
  from tavily import TavilyClient
23
  from sentence_transformers import SentenceTransformer, util
24
  from tenacity import retry, stop_after_attempt, wait_exponential
@@ -41,7 +42,7 @@ LLM_API_CONFIG = {
41
 
42
  LLM_MODEL_CONFIG = {
43
  "model": os.getenv("LLM_MODEL", "gemini-3-pro"),
44
- "max_tokens": int(os.getenv("MAX_TOKENS", 2000)),
45
  "temperature": float(os.getenv("TEMPERATURE", 0.3)),
46
  "seed": int(os.getenv("LLM_SEED", 42)),
47
  }
@@ -125,7 +126,7 @@ def perform_web_search(query: str, max_results: int = 5) -> str:
125
  class ChatPipeline:
126
  def __init__(self):
127
  self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
128
- self.llm_client = OpenAI(
129
  api_key=LLM_API_CONFIG["api_key"],
130
  base_url=LLM_API_CONFIG["base_url"],
131
  default_headers={
@@ -135,22 +136,22 @@ class ChatPipeline:
135
  )
136
 
137
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
138
- def _llm_call(self, messages: List[Dict[str, str]], max_tokens: int = None) -> str:
139
  token_est = estimate_tokens(messages)
140
  if token_est > 50000:
141
  raise ValueError("輸入過長")
142
 
143
- response = self.llm_client.chat.completions.create(
144
  model=LLM_MODEL_CONFIG["model"],
145
  messages=messages,
146
  max_tokens=max_tokens or LLM_MODEL_CONFIG["max_tokens"],
147
  temperature=LLM_MODEL_CONFIG["temperature"],
148
  seed=LLM_MODEL_CONFIG["seed"],
149
- timeout=30.0,
150
  )
151
  return response.choices[0].message.content or ""
152
 
153
- def _needs_search(self, user_text: str, history: List[Dict[str, str]]) -> bool:
154
  """輕量判斷是否需要網路搜尋"""
155
  router_prompt = [
156
  {"role": "system", "content": "你只需要判斷用戶問題是否需要最新的網路資訊來回答。"
@@ -161,7 +162,7 @@ class ChatPipeline:
161
  {"role": "user", "content": user_text}
162
  ]
163
  try:
164
- decision = self._llm_call(router_prompt, max_tokens=10).strip().lower()
165
  print(f"搜尋需求判斷:{decision}(問題:{user_text})")
166
  return decision == "yes"
167
  except Exception as e:
@@ -178,7 +179,7 @@ class ChatPipeline:
178
  conversations.pop(user_id, None)
179
  pending_chunks.pop(user_id, None)
180
 
181
- def answer_question(self, user_id: str, user_text: str) -> str:
182
  if user_text.strip().lower() == "/clear":
183
  self.clear_conversation_history(user_id)
184
  return "對話紀錄已清除!現在開始新的對話。"
@@ -186,11 +187,12 @@ class ChatPipeline:
186
  history = self.get_conversation_history(user_id)
187
 
188
  # ---- 新增:判斷是否需要搜尋 ----
189
- needs_search = self._needs_search(user_text, history)
190
 
191
  search_results = None
192
  if needs_search:
193
- search_results = perform_web_search(user_text)
 
194
 
195
  # ---- 建構最終 prompt ----
196
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
@@ -200,7 +202,7 @@ class ChatPipeline:
200
  if search_results and "沒有找到" not in search_results and "錯誤" not in search_results:
201
  messages.append({"role": "system", "content": f"網路搜尋結果(僅在高度相關時使用):{search_results}"})
202
 
203
- response = self._llm_call(messages)
204
  response = response.replace('*', '')
205
 
206
  # 更新歷史(包含最終回應)
@@ -215,7 +217,7 @@ class ChatPipeline:
215
  {"role": "system", "content": "請將以下內容生成一個簡潔但完整的中文摘要,保留關鍵事實和細節,長度控制在2000字元內。"},
216
  {"role": "user", "content": response}
217
  ]
218
- summary = self._llm_call(summary_prompt).replace('*', '')
219
  return summary + "\n\n(完整回應過長,已提供摘要。如需細節,請分次詢問或回覆「繼續」)"
220
 
221
  return response
@@ -277,7 +279,7 @@ async def line_webhook(request: Request):
277
  continue
278
 
279
  # 正常回應
280
- ai_response = chat_pipeline.answer_question(user_id, user_text)
281
  chunks = split_text_for_line(ai_response)
282
 
283
  if len(chunks) <= 5:
 
2
  # -*- coding: utf-8 -*-
3
  # ---------- 環境與快取設定 (應置於最前) ----------
4
  import os
5
+ import asyncio
6
  from typing import List, Dict
7
  from contextlib import asynccontextmanager
8
  from fastapi import FastAPI, Request, HTTPException
 
19
  from linebot.v3.webhook import WebhookParser
20
  from linebot.v3.exceptions import InvalidSignatureError
21
 
22
+ from openai import AsyncOpenAI
23
  from tavily import TavilyClient
24
  from sentence_transformers import SentenceTransformer, util
25
  from tenacity import retry, stop_after_attempt, wait_exponential
 
42
 
43
  LLM_MODEL_CONFIG = {
44
  "model": os.getenv("LLM_MODEL", "gemini-3-pro"),
45
+ "max_tokens": int(os.getenv("MAX_TOKENS", 4000)),
46
  "temperature": float(os.getenv("TEMPERATURE", 0.3)),
47
  "seed": int(os.getenv("LLM_SEED", 42)),
48
  }
 
126
  class ChatPipeline:
127
  def __init__(self):
128
  self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
129
+ self.llm_client = AsyncOpenAI(
130
  api_key=LLM_API_CONFIG["api_key"],
131
  base_url=LLM_API_CONFIG["base_url"],
132
  default_headers={
 
136
  )
137
 
138
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
139
+ async def _llm_call(self, messages: List[Dict[str, str]], max_tokens: int = None) -> str:
140
  token_est = estimate_tokens(messages)
141
  if token_est > 50000:
142
  raise ValueError("輸入過長")
143
 
144
+ response = await self.llm_client.chat.completions.create(
145
  model=LLM_MODEL_CONFIG["model"],
146
  messages=messages,
147
  max_tokens=max_tokens or LLM_MODEL_CONFIG["max_tokens"],
148
  temperature=LLM_MODEL_CONFIG["temperature"],
149
  seed=LLM_MODEL_CONFIG["seed"],
150
+ timeout=120.0,
151
  )
152
  return response.choices[0].message.content or ""
153
 
154
+ async def _needs_search(self, user_text: str, history: List[Dict[str, str]]) -> bool:
155
  """輕量判斷是否需要網路搜尋"""
156
  router_prompt = [
157
  {"role": "system", "content": "你只需要判斷用戶問題是否需要最新的網路資訊來回答。"
 
162
  {"role": "user", "content": user_text}
163
  ]
164
  try:
165
+ decision = (await self._llm_call(router_prompt, max_tokens=10)).strip().lower()
166
  print(f"搜尋需求判斷:{decision}(問題:{user_text})")
167
  return decision == "yes"
168
  except Exception as e:
 
179
  conversations.pop(user_id, None)
180
  pending_chunks.pop(user_id, None)
181
 
182
+ async def answer_question(self, user_id: str, user_text: str) -> str:
183
  if user_text.strip().lower() == "/clear":
184
  self.clear_conversation_history(user_id)
185
  return "對話紀錄已清除!現在開始新的對話。"
 
187
  history = self.get_conversation_history(user_id)
188
 
189
  # ---- 新增:判斷是否需要搜尋 ----
190
+ needs_search = await self._needs_search(user_text, history)
191
 
192
  search_results = None
193
  if needs_search:
194
+ # search is sync, but fast. Consider wrapping in to_thread if blocking is an issue.
195
+ search_results = await asyncio.to_thread(perform_web_search, user_text)
196
 
197
  # ---- 建構最終 prompt ----
198
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
 
202
  if search_results and "沒有找到" not in search_results and "錯誤" not in search_results:
203
  messages.append({"role": "system", "content": f"網路搜尋結果(僅在高度相關時使用):{search_results}"})
204
 
205
+ response = await self._llm_call(messages)
206
  response = response.replace('*', '')
207
 
208
  # 更新歷史(包含最終回應)
 
217
  {"role": "system", "content": "請將以下內容生成一個簡潔但完整的中文摘要,保留關鍵事實和細節,長度控制在2000字元內。"},
218
  {"role": "user", "content": response}
219
  ]
220
+ summary = (await self._llm_call(summary_prompt)).replace('*', '')
221
  return summary + "\n\n(完整回應過長,已提供摘要。如需細節,請分次詢問或回覆「繼續」)"
222
 
223
  return response
 
279
  continue
280
 
281
  # 正常回應
282
+ ai_response = await chat_pipeline.answer_question(user_id, user_text)
283
  chunks = split_text_for_line(ai_response)
284
 
285
  if len(chunks) <= 5: