manipulative-detector / src /streamlit_app.py
LilithHu's picture
Update src/streamlit_app.py
4c3b5cd verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import requests
# ====================== Page configuration ======================
# st.set_page_config must be the first Streamlit call executed by the script.
st.set_page_config(
    page_title="🧠 Manipulative Detector",
    page_icon="🧠",
    layout="centered",
)

# Global look-and-feel overrides (font, card container, text area, buttons),
# injected as raw CSS through a markdown element.
_PAGE_CSS = """
<style>
html, body, [class*="css"] {
font-family: 'Segoe UI', sans-serif;
background-color: #f2f2f7;
}
.main {
background-color: white;
padding: 2rem;
border-radius: 16px;
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
}
.stTextArea textarea {
background-color: #fdfdfd !important;
border-radius: 12px !important;
padding: 12px !important;
}
.stButton button {
background-color: #3b82f6 !important;
color: white !important;
font-weight: 600;
border-radius: 8px;
padding: 0.5rem 1.5rem;
}
.stButton button:hover {
background-color: #2563eb !important;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# ====================== Model loading ======================
@st.cache_resource
def load_model():
    """Load the fine-tuned sequence classifier and its tokenizer.

    Decorated with ``st.cache_resource`` so the pair is built once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        tuple: ``(tokenizer, model)``; the model is put in eval mode
        (dropout disabled) before being returned.
    """
    checkpoint = "LilithHu/mbert-manipulative-detector"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # nn.Module.eval() returns the module itself, so it can be chained.
    clf = AutoModelForSequenceClassification.from_pretrained(checkpoint).eval()
    return tok, clf
# Module-level handles used by predict(); load_model() is cached, so this is
# cheap on every rerun after the first.
tokenizer, model = load_model()
# This runs before the language selector is created, so the message cannot be
# localized from `lang`; show it bilingually like the other pre-selection
# widgets instead of Chinese-only.
st.sidebar.success("✅ Model loaded successfully / 最新模型已成功加载!")
# ====================== Language selection ======================
lang = st.sidebar.selectbox("Language / 语言", ["English", "中文"])
st.sidebar.markdown("---")

# ====================== Title ======================
# (title, description) pairs keyed by the sidebar language choice.
_HEADER_COPY = {
    "English": (
        "🧠 Manipulative Language Detector",
        "This tool uses an AI model to detect manipulative language in messages.",
    ),
    "中文": (
        "🧠 情感操控语言识别器",
        "本工具使用 AI 模型检测文本中的情感操控语言。",
    ),
}
# Anything other than "English" falls back to the Chinese copy.
_header_title, _header_blurb = _HEADER_COPY["English" if lang == "English" else "中文"]
st.title(_header_title)
st.markdown(_header_blurb)
st.markdown("---")

# ====================== User input ======================
user_input = st.text_area("Enter your message / 输入文本", height=150)
# ====================== Inference ======================
def predict(text):
    """Classify ``text`` with the module-level ``tokenizer`` and ``model``.

    The text is truncated/padded to exactly 128 tokens, so only the first
    128 tokens of a long message influence the result.

    Args:
        text: Raw message to classify.

    Returns:
        int: Index of the highest-scoring class (argmax over the logits).
        The UI treats index 1 as manipulative.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(**encoded).logits
    return scores.argmax(dim=1).item()
# ====================== Button & result display ======================
def _render_result(prediction, lang):
    """Render the coloured verdict card for a model prediction.

    Args:
        prediction: Class index returned by ``predict``. ``1`` is shown as
            manipulative; any other value is shown as non-manipulative.
        lang: Sidebar language choice; ``"中文"`` selects Chinese copy,
            anything else selects English (previously the heading was
            always Chinese, even in English mode).
    """
    is_zh = lang == "中文"
    if prediction == 1:
        label = "操纵性语言" if is_zh else "Manipulative language"
        message = (
            "该文本可能存在操纵意图,请谨慎使用。"
            if is_zh
            else "The text may contain manipulative intent. Use caution."
        )
        bg, border, title_color, icon = "#fee2e2", "#fca5a5", "#b91c1c", "⚠️"
    else:
        label = "非操纵语言" if is_zh else "Non-manipulative language"
        message = (
            "文本未检测到操纵意图,属于正常交流。"
            if is_zh
            else "No manipulative intent detected. The message seems fine."
        )
        bg, border, title_color, icon = "#d1fae5", "#6ee7b7", "#065f46", "✅"
    # Built as a single line: st.markdown parses Markdown first, and indented
    # lines inside a multi-line HTML string can be rendered as a code block.
    # Only fixed app copy is interpolated here, never user input, so allowing
    # raw HTML does not echo untrusted content.
    card_html = (
        f"<div style='background-color:{bg}; padding:20px; border-radius:12px; "
        f"border: 1px solid {border};'>"
        f"<h3 style='color:{title_color};'>{icon} {label}</h3>"
        f"<p>{message}</p>"
        "</div>"
    )
    st.markdown(card_html, unsafe_allow_html=True)


if st.button("🔍 开始分析 / Analyze"):
    if not user_input.strip():
        st.warning("⚠️ 请输入文本!" if lang == "中文" else "⚠️ Please enter some text!")
    else:
        with st.spinner("正在分析..." if lang == "中文" else "Analyzing..."):
            # Keep the try narrow: only inference failures (model/tokenizer
            # errors) are reported as a user-facing error message.
            try:
                prediction = predict(user_input)
            except Exception as e:  # UI boundary: surface the failure instead of crashing
                prefix = "❌ 错误" if lang == "中文" else "❌ Error"
                st.error(f"{prefix}: {e}")
            else:
                _render_result(prediction, lang)
# ====================== Footer ======================
_FOOTER_HTML = "<p style='text-align: center; color: #888;'>© 2025 Manipulative Language Detector</p>"
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)