Update app.py
Browse files
app.py
CHANGED
|
@@ -6,10 +6,14 @@ import joblib
|
|
| 6 |
import numpy as np
|
| 7 |
import librosa
|
| 8 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
from deepface import DeepFace
|
| 11 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
| 12 |
|
|
|
|
| 13 |
# --- 1. 下載並載入 SVM 模型 ---
|
| 14 |
# 這裡 repo_id 填你的模型倉庫路徑,例如 "GCLing/emotion-svm-model"
|
| 15 |
# filename 填上傳到該倉庫的檔案名,例如 "svm_emotion_model.joblib"
|
|
@@ -20,12 +24,21 @@ svm_model = joblib.load(model_path)
|
|
| 20 |
print("SVM model loaded.")
|
| 21 |
|
| 22 |
# --- 2. 載入文字情緒分析模型 ---
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# --- 3. 聲音特徵擷取函式 ---
|
| 31 |
def extract_feature(signal: np.ndarray, sr: int) -> np.ndarray:
|
|
@@ -89,22 +102,44 @@ def predict_voice(audio_path: str):
|
|
| 89 |
|
| 90 |
|
| 91 |
|
| 92 |
-
def
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
if not text or text.strip() == "":
|
| 95 |
return {}
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
if
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
return result
|
| 105 |
except Exception as e:
|
| 106 |
-
print("
|
| 107 |
-
return {}
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
# --- 5. 建立 Gradio 介面 ---
|
|
@@ -128,10 +163,11 @@ with gr.Blocks() as demo:
|
|
| 128 |
audio.change(fn=predict_voice, inputs=audio, outputs=audio_output)
|
| 129 |
|
| 130 |
with gr.TabItem("文字情緒"):
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
| 135 |
|
| 136 |
|
| 137 |
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import librosa
|
| 8 |
import gradio as gr
|
| 9 |
+
import time
|
| 10 |
+
import re
|
| 11 |
+
from transformers import pipeline
|
| 12 |
from huggingface_hub import hf_hub_download
|
| 13 |
from deepface import DeepFace
|
| 14 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
| 15 |
|
| 16 |
+
|
| 17 |
# --- 1. 下載並載入 SVM 模型 ---
|
| 18 |
# 這裡 repo_id 填你的模型倉庫路徑,例如 "GCLing/emotion-svm-model"
|
| 19 |
# filename 填上傳到該倉庫的檔案名,例如 "svm_emotion_model.joblib"
|
|
|
|
| 24 |
print("SVM model loaded.")
|
| 25 |
|
| 26 |
# --- 2. 載入文字情緒分析模型 ---
|
| 27 |
+
# Multilingual zero-shot classifier; predict_text_mixed calls it when the
# keyword rules produce no result.
zero_shot = pipeline(
    "zero-shot-classification",
    model="joeddav/xlm-roberta-large-xnli",
)

# English emotion labels offered to the zero-shot classifier.
candidate_labels = ["joy", "sadness", "anger", "fear", "surprise", "disgust"]

# Display name (Chinese) for each candidate label.
label_map_en2cn = {
    "joy": "高興",
    "sadness": "悲傷",
    "anger": "憤怒",
    "fear": "恐懼",
    "surprise": "驚訝",
    "disgust": "厭惡",
}

# Chinese keywords grouped by emotion key.
# NOTE(review): presumably consumed by keyword_emotion, which is not visible
# in this part of the file -- confirm.
emo_keywords = {
    "happy": ["開心", "快樂", "愉快", "喜悦", "喜悅", "歡喜", "興奮", "高興"],
    "angry": ["生氣", "憤怒", "不爽", "發火", "火大", "氣憤"],
    "sad": ["傷心", "難過", "哭", "難受", "心酸", "憂", "悲", "哀", "痛苦", "慘", "愁"],
    "surprise": ["驚訝", "意外", "嚇", "驚詫", "詫異", "訝異", "好奇"],
    "fear": ["怕", "恐懼", "緊張", "懼", "膽怯", "畏"],
}

# Simple list of negation words.
negations = ["不", "沒", "沒有", "別", "勿", "非"]
|
| 42 |
|
| 43 |
# --- 3. 聲音特徵擷取函式 ---
|
| 44 |
def extract_feature(signal: np.ndarray, sr: int) -> np.ndarray:
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
|
| 105 |
+
def predict_text_mixed(text: str):
    """Classify the emotion of *text*: keyword rules first, zero-shot second.

    The result is a ``{label: score}`` dict intended for a ``gr.Label``
    output. Both paths translate their emotion key through the module-level
    ``label_map_en2cn`` so the same emotion always appears under the same
    display label.

    Args:
        text: Raw user input. ``None``, empty and whitespace-only input are
            treated as "nothing to classify".

    Returns:
        ``{}`` for blank input; a single-entry dict holding the top rule
        match when ``keyword_emotion`` returns a non-empty result; otherwise
        the full zero-shot score distribution. If the zero-shot call raises,
        ``{"中性": 1.0}`` is returned so the UI still shows a value.
    """
    if not text or not text.strip():
        return {}

    # Rules take priority over the model.
    # NOTE(review): keyword_emotion is defined outside this block; it is used
    # here as a mapping of emotion key -> score -- confirm its contract.
    res = keyword_emotion(text)
    if res:
        # Report only the highest-scoring rule match.
        top_emo = max(res, key=res.get)
        # Rule keys differ from the zero-shot label names for three emotions;
        # normalise them so both paths share label_map_en2cn.
        rule_to_label = {"happy": "joy", "angry": "anger", "sad": "sadness"}
        label = rule_to_label.get(top_emo, top_emo)
        cn = label_map_en2cn.get(label, top_emo)
        return {cn: res[top_emo]}

    # No rule matched: fall back to zero-shot classification.
    try:
        out = zero_shot(
            text,
            candidate_labels=candidate_labels,
            hypothesis_template="这句话表达了{}情绪",
        )
        return {
            label_map_en2cn.get(lab.lower(), lab): float(sc)
            for lab, sc in zip(out["labels"], out["scores"])
        }
    except Exception as e:
        # UI boundary: log the failure and degrade to a neutral answer
        # instead of surfacing a stack trace to the user.
        print("zero-shot error:", e)
        return {"中性": 1.0}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
|
| 144 |
|
| 145 |
# --- 5. 建立 Gradio 介面 ---
|
|
|
|
| 163 |
audio.change(fn=predict_voice, inputs=audio, outputs=audio_output)
|
| 164 |
|
| 165 |
with gr.TabItem("文字情緒"):
|
| 166 |
+
gr.Markdown("### 文字情緒 分析 (规则+zero-shot)")
|
| 167 |
+
with gr.Row():
|
| 168 |
+
text = gr.Textbox(lines=3, placeholder="請輸入中文文字…")
|
| 169 |
+
text_out = gr.Label(label="文字情緒結果")
|
| 170 |
+
text.submit(fn=predict_text_mixed, inputs=text, outputs=text_out)
|
| 171 |
|
| 172 |
|
| 173 |
|