File size: 2,559 Bytes
5367a84
 
 
 
 
 
 
 
a9e9c14
a3d18a3
5367a84
a3d18a3
3d360d7
5367a84
a3d18a3
5367a84
a3d18a3
5367a84
3d360d7
5367a84
 
 
 
3d360d7
5367a84
 
 
3d360d7
a3d18a3
5367a84
 
 
a3d18a3
5367a84
 
 
a3d18a3
a9e9c14
 
a3d18a3
5367a84
 
 
 
3d360d7
5367a84
a9e9c14
5367a84
3d360d7
a3d18a3
5367a84
 
a3d18a3
5367a84
 
 
 
a9e9c14
5367a84
a9e9c14
 
5367a84
a9e9c14
5367a84
a9e9c14
 
 
 
5367a84
3d360d7
a9e9c14
 
 
3d360d7
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from fastapi import FastAPI, Request
from sentence_transformers import SentenceTransformer, util
import json
import torch
import os

app = FastAPI()

# 阈值 (低于这个分数强制返回 neutral)
THRESHOLD = 0.35

print("正在加载 BGE-Large-ZH-v1.5...")
# 这里会自动下载模型,如果日志卡在这里请耐心等待
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
print("模型加载完成")

def load_data():
    if not os.path.exists('emoji_labels.json'):
        print("警告: 找不到 emoji_labels.json")
        return [], None
    with open('emoji_labels.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    texts = [item['text'] for item in data]
    # 预计算向量
    embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)
    return data, embeddings

# 初始化数据
emoji_data, emoji_embeddings = load_data()

@app.get("/")
def home():
    return "Kouri 5-Emotion System Ready"

@app.post("/match")
async def match_emoji(request: Request):
    """
    输入: {"text": "我想吃汉堡"}
    输出: {"label": "happy", "score": 0.85}
    """
    try:
        body = await request.json()
        user_text = body.get("text", "")
        
        # 兜底:空输入或者数据没加载好,返回 neutral
        if not user_text or emoji_embeddings is None:
            return {"label": "neutral", "score": 0.0}

        # 构造查询 (BGE模型建议加前缀)
        query_text = "为这个句子分类情感:" + user_text
        query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)
        
        # 计算相似度
        scores = util.cos_sim(query_emb, emoji_embeddings)[0]
        best_score = float(torch.max(scores))
        best_idx = int(torch.argmax(scores))
        
        matched_item = emoji_data[best_idx]
        
        # 打印日志
        print(f"输入: {user_text} | 匹配: {matched_item['label']} | 分数: {best_score:.4f}")

        # 逻辑判断
        if best_score > THRESHOLD:
            return {
                "label": matched_item['label'], # 例如 "happy"
                "score": best_score             # 例如 0.8512
            }
        else:
            # 分数太低,返回 neutral
            return {
                "label": "neutral", 
                "score": best_score
            }
            
    except Exception as e:
        print(f"Error: {e}")
        # 发生任何报错都返回 neutral,保证程序不崩
        return {"label": "neutral", "score": 0.0}