simler committed on
Commit
a3d18a3
·
verified ·
1 Parent(s): db8d7e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -47
app.py CHANGED
@@ -6,85 +6,66 @@ import os
6
 
7
  app = FastAPI()
8
 
9
- # ================= 配置区域 =================
10
- # 匹配阈值 (建议 0.4 - 0.5)
11
- # BGE 模型的相似度分布通常在 0.6-1.0 之间,所以阈值要设高一点
12
- THRESHOLD = 0.38
13
 
14
- print("正在加载 BGE-Large-ZH-v1.5 (中文最强模型)...")
15
- # 替换为 BAAI/bge-large-zh-v1.5
16
- # 第一次启动下载需要几十秒,请耐心等待 Space 状态变绿
17
  model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
18
- print("模型加载完成")
19
 
20
- # ================= 数据预处理 =================
21
- def load_and_encode_data():
22
  if not os.path.exists('emoji_labels.json'):
23
- print("错误: 找不到 emoji_labels.json")
24
  return [], None
25
-
26
  with open('emoji_labels.json', 'r', encoding='utf-8') as f:
27
  data = json.load(f)
28
-
29
  texts = [item['text'] for item in data]
30
-
31
- # BGE 模型建议在查询前加指令,但在这种对称匹配场景下,直接 encode 效果也很好
32
- # 预先计算库中标签的向量
33
  embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)
34
-
35
  return data, embeddings
36
 
37
- # 初始化数据
38
- emoji_data, emoji_embeddings = load_and_encode_data()
39
-
40
- # ================= API 接口 =================
41
 
42
  @app.get("/")
43
  def home():
44
- return {"status": "Kouri BGE-Large API is running"}
45
 
46
  @app.post("/match")
47
  async def match_emoji(request: Request):
 
 
 
 
48
  try:
49
  body = await request.json()
50
  user_text = body.get("text", "")
51
 
 
52
  if not user_text or emoji_embeddings is None:
53
- return {"label": None}
54
 
55
- # BGE 模型的小技巧:给查询文本加个指令前缀,效果会更精准
56
- # 意思就是告诉模型:“帮我为这句话生成表示,用来找对应的标签”
57
- query_instruction = "为这个句子生成表示以用于检索相关标签:"
58
- query_text = query_instruction + user_text
59
-
60
- # 1. 计算用户输入的向量
61
  query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)
62
 
63
- # 2. 计算相似度
64
  scores = util.cos_sim(query_emb, emoji_embeddings)[0]
65
-
66
- # 3. 找到得分最高的
67
  best_score = float(torch.max(scores))
68
  best_idx = int(torch.argmax(scores))
69
 
70
- matched_item = emoji_data[best_idx]
 
71
 
72
- # 4. 打印日志方便你在 HF 后台看
73
- print(f"用户输入: {user_text}")
74
- print(f"最高匹配: {matched_item['label']} ({matched_item['text']}) - 得分: {best_score:.4f}")
75
 
76
  if best_score > THRESHOLD:
77
- return {
78
- "label": matched_item['label'],
79
- "score": best_score,
80
- "matched_text": matched_item['text']
81
- }
82
  else:
83
- return {
84
- "label": None,
85
- "score": best_score,
86
- "reason": "low_confidence"
87
- }
88
 
89
  except Exception as e:
90
- return {"error": str(e)}
 
 
6
 
7
  app = FastAPI()
8
 
9
+ # 阈值设定
10
+ # 如果用户说的话跟5种情绪都不沾边,就返回 neutral
11
+ THRESHOLD = 0.35
 
12
 
13
+ print("正在加载 BGE-Large-ZH-v1.5...")
 
 
14
  model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
15
+ print("模型加载完成")
16
 
17
+ # 加载数据
18
+ def load_data():
19
  if not os.path.exists('emoji_labels.json'):
 
20
  return [], None
 
21
  with open('emoji_labels.json', 'r', encoding='utf-8') as f:
22
  data = json.load(f)
 
23
  texts = [item['text'] for item in data]
24
+ # 预计算向量
 
 
25
  embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)
 
26
  return data, embeddings
27
 
28
+ emoji_data, emoji_embeddings = load_data()
 
 
 
29
 
30
  @app.get("/")
31
  def home():
32
+ return "Kouri 5-Emotion System Ready"
33
 
34
  @app.post("/match")
35
  async def match_emoji(request: Request):
36
+ """
37
+ 不管输入什么,只返回 5 种标签之一。
38
+ 格式示例: {"tag": "[happy]"}
39
+ """
40
  try:
41
  body = await request.json()
42
  user_text = body.get("text", "")
43
 
44
+ # 兜底:空输入返回中立
45
  if not user_text or emoji_embeddings is None:
46
+ return {"tag": "[neutral]"}
47
 
48
+ # 构造查询指令
49
+ query_text = "为这个句子分类情感:" + user_text
 
 
 
 
50
  query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)
51
 
52
+ # 计算相似度
53
  scores = util.cos_sim(query_emb, emoji_embeddings)[0]
 
 
54
  best_score = float(torch.max(scores))
55
  best_idx = int(torch.argmax(scores))
56
 
57
+ # 获取标签 (例如 "[happy]")
58
+ matched_label = emoji_data[best_idx]['label']
59
 
60
+ # 打印日志方便调试
61
+ print(f"输入: {user_text} | 匹配: {matched_label} | 分数: {best_score:.4f}")
 
62
 
63
  if best_score > THRESHOLD:
64
+ return {"tag": matched_label}
 
 
 
 
65
  else:
66
+ # 没匹配上也返回中立,保证稳定性
67
+ return {"tag": "[neutral]"}
 
 
 
68
 
69
  except Exception as e:
70
+ print(f"Error: {e}")
71
+ return {"tag": "[neutral]"}