simler commited on
Commit
511d84d
·
verified ·
1 Parent(s): e8e1f91

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +85 -0
  2. emoji_labels.json +10 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from sentence_transformers import SentenceTransformer, util
3
+ import json
4
+ import torch
5
+ import os
6
+
7
+ app = FastAPI()
8
+
9
+ # ================= 配置区域 =================
10
+ # 设定匹配阈值 (0-1)。
11
+ # 建议 0.3 - 0.4。太高会导致匹配不到,太低会导致乱匹配。
12
+ THRESHOLD = 0.35
13
+
14
+ # 加载轻量级模型 (80MB)
15
+ # 第一次启动时会自动下载
16
+ print("正在加载模型...")
17
+ model = SentenceTransformer('all-MiniLM-L6-v2')
18
+ print("模型加载完成")
19
+
20
+ # ================= 数据预处理 =================
21
+ # 读取 JSON 文件并预计算向量
22
+ def load_and_encode_data():
23
+ if not os.path.exists('emoji_labels.json'):
24
+ print("错误: 找不到 emoji_labels.json")
25
+ return [], None
26
+
27
+ with open('emoji_labels.json', 'r', encoding='utf-8') as f:
28
+ data = json.load(f)
29
+
30
+ # 提取描述文本用于计算
31
+ texts = [item['text'] for item in data]
32
+ # 计算向量并转为 Tensor
33
+ embeddings = model.encode(texts, convert_to_tensor=True)
34
+
35
+ return data, embeddings
36
+
37
+ # 初始化数据
38
+ emoji_data, emoji_embeddings = load_and_encode_data()
39
+
40
+ # ================= API 接口 =================
41
+
42
+ @app.get("/")
43
+ def home():
44
+ return {"status": "Kouri Emotion API is running"}
45
+
46
+ @app.post("/match")
47
+ async def match_emoji(request: Request):
48
+ """
49
+ 接收 {"text": "我想吃汉堡"}
50
+ 返回 {"label": "burger", "score": 0.85}
51
+ """
52
+ try:
53
+ body = await request.json()
54
+ user_text = body.get("text", "")
55
+
56
+ if not user_text or emoji_embeddings is None:
57
+ return {"label": None, "reason": "empty_input_or_data"}
58
+
59
+ # 1. 计算用户输入的向量
60
+ query_emb = model.encode(user_text, convert_to_tensor=True)
61
+
62
+ # 2. 计算与库中所有描述的余弦相似度
63
+ scores = util.cos_sim(query_emb, emoji_embeddings)[0]
64
+
65
+ # 3. 找到得分最高的那个
66
+ best_score = float(torch.max(scores))
67
+ best_idx = int(torch.argmax(scores))
68
+
69
+ # 4. 判断是否超过阈值
70
+ if best_score > THRESHOLD:
71
+ matched_item = emoji_data[best_idx]
72
+ return {
73
+ "label": matched_item['label'],
74
+ "score": best_score,
75
+ "matched_text": matched_item['text'] # 方便调试看它匹配到了哪一条
76
+ }
77
+ else:
78
+ return {
79
+ "label": None,
80
+ "score": best_score,
81
+ "reason": "low_confidence"
82
+ }
83
+
84
+ except Exception as e:
85
+ return {"error": str(e)}
emoji_labels.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {"text": "开心快乐高兴哈哈笑嘻嘻好耶", "label": "happy"},
3
+ {"text": "难过哭泣悲伤痛苦呜呜呜", "label": "sad"},
4
+ {"text": "生气愤怒发火暴躁", "label": "angry"},
5
+ {"text": "爱你喜欢你笔芯么么哒", "label": "love"},
6
+ {"text": "无语发呆不知道说什么", "label": "neutral"},
7
+ {"text": "想吃汉堡炸鸡快餐麦当劳肯德基", "label": "burger"},
8
+ {"text": "好喝的奶茶饮料咖啡", "label": "drink"},
9
+ {"text": "看到小猫咪喵喵叫好可爱", "label": "cat"}
10
+ ]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ sentence-transformers
2
+ torch
3
+ fastapi
4
+ uvicorn
5
+ pandas