Update app.py
Browse files
app.py
CHANGED
|
@@ -10,18 +10,22 @@ app = FastAPI()
|
|
| 10 |
THRESHOLD = 0.35
|
| 11 |
|
| 12 |
print("正在加载 BGE-Large-ZH-v1.5...")
|
|
|
|
| 13 |
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
|
| 14 |
print("模型加载完成")
|
| 15 |
|
| 16 |
def load_data():
|
| 17 |
if not os.path.exists('emoji_labels.json'):
|
|
|
|
| 18 |
return [], None
|
| 19 |
with open('emoji_labels.json', 'r', encoding='utf-8') as f:
|
| 20 |
data = json.load(f)
|
| 21 |
texts = [item['text'] for item in data]
|
|
|
|
| 22 |
embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)
|
| 23 |
return data, embeddings
|
| 24 |
|
|
|
|
| 25 |
emoji_data, emoji_embeddings = load_data()
|
| 26 |
|
| 27 |
@app.get("/")
|
|
@@ -38,11 +42,11 @@ async def match_emoji(request: Request):
|
|
| 38 |
body = await request.json()
|
| 39 |
user_text = body.get("text", "")
|
| 40 |
|
| 41 |
-
# 兜底:空输入返回 neutral
|
| 42 |
if not user_text or emoji_embeddings is None:
|
| 43 |
return {"label": "neutral", "score": 0.0}
|
| 44 |
|
| 45 |
-
# 构造查询
|
| 46 |
query_text = "为这个句子分类情感:" + user_text
|
| 47 |
query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)
|
| 48 |
|
|
@@ -63,8 +67,13 @@ async def match_emoji(request: Request):
|
|
| 63 |
"score": best_score # 例如 0.8512
|
| 64 |
}
|
| 65 |
else:
|
| 66 |
-
# 分数太低,返回 neutral
|
| 67 |
return {
|
| 68 |
"label": "neutral",
|
| 69 |
"score": best_score
|
| 70 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
THRESHOLD = 0.35
|
| 11 |
|
| 12 |
print("正在加载 BGE-Large-ZH-v1.5...")
|
| 13 |
+
# 这里会自动下载模型,如果日志卡在这里请耐心等待
|
| 14 |
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
|
| 15 |
print("模型加载完成")
|
| 16 |
|
| 17 |
def load_data():
|
| 18 |
if not os.path.exists('emoji_labels.json'):
|
| 19 |
+
print("警告: 找不到 emoji_labels.json")
|
| 20 |
return [], None
|
| 21 |
with open('emoji_labels.json', 'r', encoding='utf-8') as f:
|
| 22 |
data = json.load(f)
|
| 23 |
texts = [item['text'] for item in data]
|
| 24 |
+
# 预计算向量
|
| 25 |
embeddings = model.encode(texts, normalize_embeddings=True, convert_to_tensor=True)
|
| 26 |
return data, embeddings
|
| 27 |
|
| 28 |
+
# 初始化数据
|
| 29 |
emoji_data, emoji_embeddings = load_data()
|
| 30 |
|
| 31 |
@app.get("/")
|
|
|
|
| 42 |
body = await request.json()
|
| 43 |
user_text = body.get("text", "")
|
| 44 |
|
| 45 |
+
# 兜底:空输入或者数据没加载好,返回 neutral
|
| 46 |
if not user_text or emoji_embeddings is None:
|
| 47 |
return {"label": "neutral", "score": 0.0}
|
| 48 |
|
| 49 |
+
# 构造查询 (BGE模型建议加前缀)
|
| 50 |
query_text = "为这个句子分类情感:" + user_text
|
| 51 |
query_emb = model.encode(query_text, normalize_embeddings=True, convert_to_tensor=True)
|
| 52 |
|
|
|
|
| 67 |
"score": best_score # 例如 0.8512
|
| 68 |
}
|
| 69 |
else:
|
| 70 |
+
# 分数太低,返回 neutral
|
| 71 |
return {
|
| 72 |
"label": "neutral",
|
| 73 |
"score": best_score
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"Error: {e}")
|
| 78 |
+
# 发生任何报错都返回 neutral,保证程序不崩
|
| 79 |
+
return {"label": "neutral", "score": 0.0}
|