SarahXia0405 committed on
Commit
26220ae
·
verified ·
1 Parent(s): 337c831

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -14
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import re
 
3
  from typing import List, Dict, Tuple, Optional
4
 
5
  import gradio as gr
@@ -15,6 +16,7 @@ if not OPENAI_API_KEY:
15
 
16
  client = OpenAI(api_key=OPENAI_API_KEY)
17
  DEFAULT_MODEL = "gpt-4.1-mini"
 
18
 
19
  # ---------- 默认 GenAI 课程大纲 ----------
20
  DEFAULT_COURSE_TOPICS = [
@@ -201,7 +203,6 @@ def _normalize_text(text: str) -> str:
201
  text = text.lower().strip()
202
  # 去掉标点符号,只保留字母数字和空格
203
  text = re.sub(r"[^\w\s]", " ", text)
204
- # 合并多余空格
205
  text = re.sub(r"\s+", " ", text)
206
  return text
207
 
@@ -214,27 +215,58 @@ def _jaccard_similarity(a: str, b: str) -> float:
214
  return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
215
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  def find_similar_past_question(
218
  message: str,
219
  history: List[Tuple[str, str]],
220
- similarity_threshold: float = 0.8,
 
221
  max_turns_to_check: int = 6,
222
  ) -> Optional[Tuple[str, str, float]]:
223
  """
224
  在最近若干轮历史对话中查找与当前问题相似的既往问题。
225
 
 
 
 
 
226
  返回:
227
- (past_question, past_answer, similarity) 或 None
228
  """
 
229
  norm_msg = _normalize_text(message)
230
  if not norm_msg:
231
  return None
232
 
233
- best_sim = 0.0
234
- best_pair: Optional[Tuple[str, str]] = None
235
  checked = 0
236
 
237
- # 从最近一轮往前看
238
  for user_q, assistant_a in reversed(history):
239
  checked += 1
240
  if checked > max_turns_to_check:
@@ -244,17 +276,49 @@ def find_similar_past_question(
244
  if not norm_hist_q:
245
  continue
246
 
247
- # 完全相同直接返回
248
  if norm_msg == norm_hist_q:
 
249
  return user_q, assistant_a, 1.0
250
 
251
- sim = _jaccard_similarity(norm_msg, norm_hist_q)
252
- if sim > best_sim:
253
- best_sim = sim
254
- best_pair = (user_q, assistant_a)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- if best_pair and best_sim >= similarity_threshold:
257
- return best_pair[0], best_pair[1], best_sim
258
 
259
  return None
260
 
@@ -704,7 +768,6 @@ with gr.Blocks(title="Clare – Hanbridge AI Teaching Assistant") as demo:
704
  dup = find_similar_past_question(message, chat_history)
705
  if dup is not None:
706
  past_q, past_a, sim = dup
707
- # 直接复用之前回答,并给一个简短提示
708
  prefix_en = (
709
  "I noticed this question is very similar to one you asked earlier, "
710
  "so I'm showing the previous explanation again. "
 
1
  import os
2
  import re
3
+ import math
4
  from typing import List, Dict, Tuple, Optional
5
 
6
  import gradio as gr
 
16
 
17
  client = OpenAI(api_key=OPENAI_API_KEY)
18
  DEFAULT_MODEL = "gpt-4.1-mini"
19
+ EMBEDDING_MODEL = "text-embedding-3-small"
20
 
21
  # ---------- 默认 GenAI 课程大纲 ----------
22
  DEFAULT_COURSE_TOPICS = [
 
203
  text = text.lower().strip()
204
  # 去掉标点符号,只保留字母数字和空格
205
  text = re.sub(r"[^\w\s]", " ", text)
 
206
  text = re.sub(r"\s+", " ", text)
207
  return text
208
 
 
215
  return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
216
 
217
 
218
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of two equal-length vectors.

    Degrades to 0.0 (instead of raising) for empty inputs, mismatched
    lengths, or zero-magnitude vectors, so callers can treat "no signal"
    and "no similarity" uniformly.
    """
    # Guard: anything we cannot meaningfully compare scores as 0.0.
    if not a or not b or len(a) != len(b):
        return 0.0
    # Single pass accumulating dot product and squared magnitudes.
    dot_product = 0.0
    sq_a = 0.0
    sq_b = 0.0
    for x, y in zip(a, b):
        dot_product += x * y
        sq_a += x * x
        sq_b += y * y
    norm_a = math.sqrt(sq_a)
    norm_b = math.sqrt(sq_b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot_product / (norm_a * norm_b)
227
+
228
+
229
def get_embedding(text: str) -> Optional[List[float]]:
    """Encode *text* as a vector via the OpenAI Embedding API.

    Returns the embedding vector on success, or None when the API call
    fails for any reason — the caller degrades gracefully rather than
    letting an embedding error block the main chat flow.
    """
    try:
        response = client.embeddings.create(model=EMBEDDING_MODEL, input=[text])
        return response.data[0].embedding
    except Exception:
        # Best-effort: swallow the failure and signal "no embedding".
        return None
242
+
243
+
244
  def find_similar_past_question(
245
  message: str,
246
  history: List[Tuple[str, str]],
247
+ jaccard_threshold: float = 0.65,
248
+ embedding_threshold: float = 0.85,
249
  max_turns_to_check: int = 6,
250
  ) -> Optional[Tuple[str, str, float]]:
251
  """
252
  在最近若干轮历史对话中查找与当前问题相似的既往问题。
253
 
254
+ 两级检测:
255
+ 1. 先用 Jaccard 做快速近似匹配(文本几乎一样的情况)
256
+ 2. 再用 OpenAI embedding 做语义相似度检测(改写、同义句)
257
+
258
  返回:
259
+ (past_question, past_answer, similarity_score) 或 None
260
  """
261
+ # ---------- 第一步:Jaccard 快速检测 ----------
262
  norm_msg = _normalize_text(message)
263
  if not norm_msg:
264
  return None
265
 
266
+ best_sim_j = 0.0
267
+ best_pair_j: Optional[Tuple[str, str]] = None
268
  checked = 0
269
 
 
270
  for user_q, assistant_a in reversed(history):
271
  checked += 1
272
  if checked > max_turns_to_check:
 
276
  if not norm_hist_q:
277
  continue
278
 
 
279
  if norm_msg == norm_hist_q:
280
+ # 完全相同,直接视为重复
281
  return user_q, assistant_a, 1.0
282
 
283
+ sim_j = _jaccard_similarity(norm_msg, norm_hist_q)
284
+ if sim_j > best_sim_j:
285
+ best_sim_j = sim_j
286
+ best_pair_j = (user_q, assistant_a)
287
+
288
+ if best_pair_j and best_sim_j >= jaccard_threshold:
289
+ # 词面高度相似,直接视为重复
290
+ return best_pair_j[0], best_pair_j[1], best_sim_j
291
+
292
+ # ---------- 第二步:Embedding 语义相似度 ----------
293
+ # 如果历史太少,就没必要算 embedding
294
+ if not history:
295
+ return None
296
+
297
+ msg_emb = get_embedding(message)
298
+ if msg_emb is None:
299
+ # embedding 调用失败,放弃语义检测
300
+ return None
301
+
302
+ best_sim_e = 0.0
303
+ best_pair_e: Optional[Tuple[str, str]] = None
304
+ checked = 0
305
+
306
+ for user_q, assistant_a in reversed(history):
307
+ checked += 1
308
+ if checked > max_turns_to_check:
309
+ break
310
+
311
+ hist_emb = get_embedding(user_q)
312
+ if hist_emb is None:
313
+ continue
314
+
315
+ sim_e = cosine_similarity(msg_emb, hist_emb)
316
+ if sim_e > best_sim_e:
317
+ best_sim_e = sim_e
318
+ best_pair_e = (user_q, assistant_a)
319
 
320
+ if best_pair_e and best_sim_e >= embedding_threshold:
321
+ return best_pair_e[0], best_pair_e[1], best_sim_e
322
 
323
  return None
324
 
 
768
  dup = find_similar_past_question(message, chat_history)
769
  if dup is not None:
770
  past_q, past_a, sim = dup
 
771
  prefix_en = (
772
  "I noticed this question is very similar to one you asked earlier, "
773
  "so I'm showing the previous explanation again. "