|
|
from fastapi import FastAPI, Request |
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
import json |
|
|
import torch |
|
|
import os |
|
|
|
|
|
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# Similarity cut-off used by the /match handler: the best match must score
# strictly above this value, otherwise the handler answers "neutral".
THRESHOLD = 0.35

# The embedding model is loaded once, at import time, and shared by
# load_data() and every request.
print("正在加载 BGE-Large-ZH-v1.5...")
model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
print("模型加载完成")
|
|
|
|
|
def load_data():
    """Load the labelled examples and pre-compute their embeddings.

    Reads ``emoji_labels.json`` from the current working directory. Each
    entry must provide a ``'text'`` key, which is what gets embedded here.

    Returns:
        A ``(data, embeddings)`` pair. ``data`` is the parsed JSON content and
        ``embeddings`` holds one normalized embedding per entry, in the same
        order. When the file is missing or contains no entries the pair is
        ``([], None)``, so callers only need to test ``embeddings is None``.
    """
    path = 'emoji_labels.json'

    # Open directly and handle the miss, rather than checking existence first.
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print("警告: 找不到 emoji_labels.json")
        return [], None

    texts = [item['text'] for item in data]

    # An empty label set cannot be matched against; report it the same way as
    # a missing file instead of handing the encoder an empty batch.
    if not texts:
        print("警告: emoji_labels.json 中没有任何条目")
        return [], None

    embeddings = model.encode(
        texts, normalize_embeddings=True, convert_to_tensor=True
    )
    return data, embeddings
|
|
|
|
|
|
|
|
# Build the reference set once at import time; emoji_embeddings is None when
# load_data() could not find the label file.
emoji_data, emoji_embeddings = load_data()
|
|
|
|
|
@app.get("/")
def home():
    """Answer GET / with a fixed status string."""
    status = "Kouri 5-Emotion System Ready"
    return status
|
|
|
|
|
@app.post("/match")
async def match_emoji(request: Request):
    """Pick the emotion label whose reference text is closest to the input.

    Request body (JSON): {"text": "我想吃汉堡"}
    Response:            {"label": "happy", "score": 0.85}

    The input is embedded and compared by cosine similarity against the
    pre-computed reference embeddings. The best match's label is returned
    when its score is strictly greater than THRESHOLD; otherwise the label is
    "neutral" with that same score. Empty text, missing reference embeddings,
    or any exception yields {"label": "neutral", "score": 0.0}.
    """
    try:
        payload = await request.json()
        user_text = payload.get("text", "")

        # Nothing to classify, or no reference embeddings were loaded.
        if emoji_embeddings is None or not user_text:
            return {"label": "neutral", "score": 0.0}

        # The query is encoded with an instruction prefix in front of the text.
        prompt = "为这个句子分类情感:" + user_text
        query_vec = model.encode(
            prompt, normalize_embeddings=True, convert_to_tensor=True
        )

        # One row of similarities: the query against every reference entry.
        similarities = util.cos_sim(query_vec, emoji_embeddings)[0]
        top_idx = int(torch.argmax(similarities))
        top_score = float(similarities[top_idx])
        top_label = emoji_data[top_idx]['label']

        print(f"输入: {user_text} | 匹配: {top_label} | 分数: {top_score:.4f}")

        # Below or at the threshold the match is too weak to report.
        label = top_label if top_score > THRESHOLD else "neutral"
        return {"label": label, "score": top_score}

    except Exception as e:
        # Best-effort endpoint: any failure degrades to the neutral answer.
        print(f"Error: {e}")
        return {"label": "neutral", "score": 0.0}