File size: 2,559 Bytes
5367a84 a9e9c14 a3d18a3 5367a84 a3d18a3 3d360d7 5367a84 a3d18a3 5367a84 a3d18a3 5367a84 3d360d7 5367a84 3d360d7 5367a84 3d360d7 a3d18a3 5367a84 a3d18a3 5367a84 a3d18a3 a9e9c14 a3d18a3 5367a84 3d360d7 5367a84 a9e9c14 5367a84 3d360d7 a3d18a3 5367a84 a3d18a3 5367a84 a9e9c14 5367a84 a9e9c14 5367a84 a9e9c14 5367a84 a9e9c14 5367a84 3d360d7 a9e9c14 3d360d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from fastapi import FastAPI, Request
from sentence_transformers import SentenceTransformer, util
import json
import torch
import os
# FastAPI application object that the route decorators below attach to.
app = FastAPI()
# Similarity threshold: a best-match score that does not exceed this value
# is reported as "neutral" by match_emoji.
THRESHOLD = 0.35
print("正在加载 BGE-Large-ZH-v1.5...")
# Original author's note: the model is downloaded automatically at this
# point, so if the log appears stuck here, be patient.
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
print("模型加载完成")
def load_data():
    """Load the labelled reference sentences and pre-compute their embeddings.

    Reads ``emoji_labels.json`` from the current working directory. Each
    entry of the parsed JSON must provide a ``text`` value, which is what
    gets embedded here.

    Returns:
        A ``(data, embeddings)`` pair, where ``data`` is the parsed JSON
        content and ``embeddings`` is the result of ``model.encode`` over
        every entry's ``text`` (normalized, returned as a tensor).
        ``([], None)`` is returned when the file is missing or holds no
        entries, so callers can treat ``None`` embeddings as "no reference
        data available".

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        KeyError: If an entry has no ``text`` key.
    """
    if not os.path.exists('emoji_labels.json'):
        print("警告: 找不到 emoji_labels.json")
        return [], None

    with open('emoji_labels.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # An empty label set leaves nothing to match against. Report it the same
    # way as a missing file instead of encoding zero sentences, so that the
    # `emoji_embeddings is None` check in match_emoji also covers this case.
    if not data:
        print("警告: emoji_labels.json 为空")
        return [], None

    texts = [item['text'] for item in data]
    # Pre-compute the reference embeddings once so requests only have to
    # encode the incoming query.
    embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)
    return data, embeddings
# Initialize the reference data once at import time: the parsed label entries
# and their pre-computed embeddings (embeddings is None when the label file
# is missing).
emoji_data, emoji_embeddings = load_data()
@app.get("/")
def home():
    """Root endpoint: answer with a fixed status string."""
    status_message = "Kouri 5-Emotion System Ready"
    return status_message
@app.post("/match")
async def match_emoji(request: Request):
    """Classify the emotion of a sentence by its nearest labelled example.

    Request body:  {"text": "我想吃汉堡"}
    Response body: {"label": "happy", "score": 0.85}

    The text is embedded with the module-level ``model`` and compared by
    cosine similarity against every row of ``emoji_embeddings``. The label
    of the best-scoring entry in ``emoji_data`` is returned when its score
    is greater than ``THRESHOLD``; otherwise the label is "neutral" and the
    best score is still reported.

    ``{"label": "neutral", "score": 0.0}`` is returned when the body is not
    a JSON object, ``text`` is missing, not a string, empty or only
    whitespace, the reference data is unavailable, or any exception occurs
    while handling the request.

    NOTE(review): ``model.encode`` is called synchronously inside this
    coroutine, so the event loop is presumably blocked while a query is
    being encoded. Consider offloading it if concurrent throughput matters.
    """
    try:
        body = await request.json()

        # A JSON array/number/string body has no "text" field to read.
        user_text = body.get("text", "") if isinstance(body, dict) else ""

        # Fallback: unusable input, or reference data that was never loaded.
        if (
            not isinstance(user_text, str)
            or not user_text.strip()
            or emoji_embeddings is None
            or not emoji_data
        ):
            return {"label": "neutral", "score": 0.0}

        # Build the query; the original author notes that BGE models
        # recommend adding an instruction prefix.
        query_text = "为这个句子分类情感:" + user_text
        query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)

        # Similarity of the query against every reference sentence.
        scores = util.cos_sim(query_emb, emoji_embeddings)[0]
        # Locate the best match once and read its score from that index
        # instead of reducing the tensor twice.
        best_idx = int(torch.argmax(scores))
        best_score = float(scores[best_idx])
        matched_item = emoji_data[best_idx]

        # Log the raw best match, even if it ends up below the threshold.
        print(f"输入: {user_text} | 匹配: {matched_item['label']} | 分数: {best_score:.4f}")

        # A score that does not exceed the threshold is too weak to trust.
        label = matched_item['label'] if best_score > THRESHOLD else "neutral"
        return {"label": label, "score": best_score}
    except Exception as e:
        # Deliberate catch-all: any failure degrades to a neutral answer so
        # the endpoint never surfaces an error to the caller.
        print(f"Error: {e}")
        return {"label": "neutral", "score": 0.0}