simler committed on
Commit
5367a84
·
verified ·
1 Parent(s): 10cb2e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -84
app.py CHANGED
@@ -1,85 +1,90 @@
1
- from fastapi import FastAPI, Request
2
- from sentence_transformers import SentenceTransformer, util
3
- import json
4
- import torch
5
- import os
6
-
7
app = FastAPI()

# ================= Configuration =================
# Match threshold (0-1).
# 0.3 - 0.4 is suggested: too high and nothing matches, too low and the
# matches become arbitrary.
THRESHOLD = 0.35

# Load the lightweight model (80MB).
# It is downloaded automatically on the first startup.
print("正在加载模型...")
model = SentenceTransformer('all-MiniLM-L6-v2')
print("模型加载完成")
19
-
20
- # ================= 数据预处理 =================
21
- # 读取 JSON 文件并预计算向量
22
def load_and_encode_data():
    """Load the label entries from ``emoji_labels.json`` and embed their texts.

    Returns:
        A ``(data, embeddings)`` tuple. ``data`` is the parsed JSON content
        (iterated as a sequence of objects that each have a ``'text'`` key)
        and ``embeddings`` is what ``model.encode`` returns for those texts,
        in the same order. ``([], None)`` is returned when the file does not
        exist or contains no entries.
    """
    if not os.path.exists('emoji_labels.json'):
        print("错误: 找不到 emoji_labels.json")
        return [], None

    with open('emoji_labels.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    if not data:
        # With nothing to embed, report "no data" the same way as a missing
        # file so callers that check ``embeddings is None`` short-circuit
        # instead of running a max/argmax over an empty score tensor.
        print("错误: emoji_labels.json 中没有任何条目")
        return [], None

    # Collect the description texts that will be embedded.
    texts = [item['text'] for item in data]
    # Encode them in one call and keep the result as a tensor.
    embeddings = model.encode(texts, convert_to_tensor=True)

    return data, embeddings
36
-
37
# Initialise the data: load the label entries and precompute their embeddings
# once, when the module is imported.
emoji_data, emoji_embeddings = load_and_encode_data()
39
-
40
- # ================= API 接口 =================
41
-
42
@app.get("/")
def home():
    """Return a fixed status payload for the root route."""
    status_payload = {"status": "Kouri Emotion API is running"}
    return status_payload
45
-
46
- @app.post("/match")
47
async def match_emoji(request: Request):
    """Match the posted text against the precomputed label embeddings.

    Expects a JSON body such as ``{"text": "我想吃汉堡"}``.

    Returns:
        ``{"label", "score", "matched_text"}`` when the best cosine
        similarity is above ``THRESHOLD``;
        ``{"label": None, "score", "reason": "low_confidence"}`` when it is
        not; ``{"label": None, "reason": "empty_input_or_data"}`` when there
        is no usable text or no label embeddings; ``{"error": <message>}`` if
        anything raises while handling the request.
    """
    try:
        body = await request.json()

        # A JSON body that is not an object, or a "text" value that is not a
        # string, is treated as empty input instead of falling into the
        # blanket handler below with a raw Python error message.
        user_text = body.get("text", "") if isinstance(body, dict) else ""
        if not isinstance(user_text, str):
            user_text = ""

        # Whitespace-only text carries nothing to match on.
        if not user_text.strip() or emoji_embeddings is None:
            return {"label": None, "reason": "empty_input_or_data"}

        # 1. Embed the user's text.
        query_emb = model.encode(user_text, convert_to_tensor=True)

        # 2. Cosine similarity against every stored description.
        scores = util.cos_sim(query_emb, emoji_embeddings)[0]

        # 3. Find the highest-scoring entry.
        best_score = float(torch.max(scores))
        best_idx = int(torch.argmax(scores))

        # 4. Only report a label when the score clears the threshold.
        if best_score > THRESHOLD:
            matched_item = emoji_data[best_idx]
            return {
                "label": matched_item['label'],
                "score": best_score,
                # Included so it is easy to see which entry was matched.
                "matched_text": matched_item['text'],
            }
        return {
            "label": None,
            "score": best_score,
            "reason": "low_confidence",
        }

    except Exception as e:
        # Same best-effort response as before, but the traceback is written to
        # the process log so the failure is not lost.
        import traceback
        traceback.print_exc()
        return {"error": str(e)}
 
1
+ from fastapi import FastAPI, Request
2
+ from sentence_transformers import SentenceTransformer, util
3
+ import json
4
+ import torch
5
+ import os
6
+
7
app = FastAPI()

# ================= Configuration =================
# Match threshold (0.4 - 0.5 suggested).
# Author's note: BGE similarity scores usually fall between 0.6 and 1.0, so
# the threshold is set somewhat higher.
THRESHOLD = 0.45

print("正在加载 BGE-Large-ZH-v1.5 (中文最强模型)...")
# Switched to BAAI/bge-large-zh-v1.5.
# Author's note: the first startup takes tens of seconds to download the
# model; wait for the Space status to turn green.
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
print("模型加载完成!")
19
+
20
+ # ================= 数据预处理 =================
21
def load_and_encode_data():
    """Load the label entries from ``emoji_labels.json`` and embed their texts.

    Returns:
        A ``(data, embeddings)`` tuple. ``data`` is the parsed JSON content
        (iterated as a sequence of objects that each have a ``'text'`` key)
        and ``embeddings`` is what ``model.encode`` returns for those texts
        with ``normalize_embeddings=True``, in the same order. ``([], None)``
        is returned when the file does not exist or contains no entries.
    """
    if not os.path.exists('emoji_labels.json'):
        print("错误: 找不到 emoji_labels.json")
        return [], None

    with open('emoji_labels.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    if not data:
        # With nothing to embed, report "no data" the same way as a missing
        # file so callers that check ``embeddings is None`` short-circuit
        # instead of running a max/argmax over an empty score tensor.
        print("错误: emoji_labels.json 中没有任何条目")
        return [], None

    texts = [item['text'] for item in data]

    # Author's note: BGE suggests an instruction prefix for queries, but for
    # this symmetric matching case the label texts are encoded as-is.
    # Precompute the embeddings for every label text in the library.
    embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)

    return data, embeddings
36
+
37
# Initialise the data: load the label entries and precompute their embeddings
# once, when the module is imported.
emoji_data, emoji_embeddings = load_and_encode_data()
39
+
40
+ # ================= API 接口 =================
41
+
42
@app.get("/")
def home():
    """Return a fixed status payload for the root route."""
    status_payload = {"status": "Kouri BGE-Large API is running"}
    return status_payload
45
+
46
+ @app.post("/match")
47
async def match_emoji(request: Request):
    """Match the posted text against the precomputed label embeddings.

    Expects a JSON body with a ``"text"`` string. The text is prefixed with a
    retrieval instruction before being embedded, then compared by cosine
    similarity with every stored label embedding.

    Returns:
        ``{"label", "score", "matched_text"}`` when the best similarity is
        above ``THRESHOLD``;
        ``{"label": None, "score", "reason": "low_confidence"}`` when it is
        not; ``{"label": None}`` when there is no usable text or no label
        embeddings; ``{"error": <message>}`` if anything raises while
        handling the request.
    """
    try:
        body = await request.json()

        # A JSON body that is not an object, or a "text" value that is not a
        # string, is treated as empty input. Otherwise the string
        # concatenation below would raise and the client would receive a raw
        # Python error message from the blanket handler.
        user_text = body.get("text", "") if isinstance(body, dict) else ""
        if not isinstance(user_text, str):
            user_text = ""

        # Whitespace-only text carries nothing to match on.
        if not user_text.strip() or emoji_embeddings is None:
            return {"label": None}

        # Author's tip for BGE: prefix the query with an instruction so the
        # model produces a representation meant for finding the matching
        # label.
        query_instruction = "为这个句子生成表示以用于检索相关标签:"
        query_text = query_instruction + user_text

        # 1. Embed the user's text.
        query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)

        # 2. Compute the similarity against every stored label embedding.
        scores = util.cos_sim(query_emb, emoji_embeddings)[0]

        # 3. Find the highest-scoring entry.
        best_score = float(torch.max(scores))
        best_idx = int(torch.argmax(scores))

        matched_item = emoji_data[best_idx]

        # 4. Log the best candidate so it can be inspected in the hosting
        #    backend, even when it ends up below the threshold.
        print(f"用户输入: {user_text}")
        print(f"最高匹配: {matched_item['label']} ({matched_item['text']}) - 得分: {best_score:.4f}")

        if best_score > THRESHOLD:
            return {
                "label": matched_item['label'],
                "score": best_score,
                "matched_text": matched_item['text'],
            }
        return {
            "label": None,
            "score": best_score,
            "reason": "low_confidence",
        }

    except Exception as e:
        # Same best-effort response as before, but the traceback is written to
        # the process log so the failure is not lost.
        import traceback
        traceback.print_exc()
        return {"error": str(e)}