# Classify / app.py
# strongeryongchao - Update app.py (849938b, verified)
# app.py
import gradio as gr
from transformers import pipeline
import pandas as pd
import tempfile
import os
# Build the sentiment-analysis pipeline once at import time; the prediction
# functions below all reuse this single module-level instance.
model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
classifier = pipeline(task="text-classification", model=model_name)
def predict_single(text):
    """Classify one sentence and return both class probabilities as strings.

    Args:
        text: Sentence to analyse. ``None``, empty, or whitespace-only input
            is treated as "nothing to classify".

    Returns:
        dict: Two entries keyed by the result-column names
        ``"negative (stars 1, 2 and 3)"`` and ``"positive (stars 4 and 5)"``.
        Values are probabilities formatted with six decimals. Blank input
        returns ``"0.0000"`` for both keys without calling the model.
    """
    negative_key = "negative (stars 1, 2 and 3)"
    positive_key = "positive (stars 4 and 5)"

    # Guard against None as well as blank strings so the UI never crashes here.
    if not text or not text.strip():
        return {negative_key: "0.0000", positive_key: "0.0000"}

    result = classifier(text)[0]
    score = result["score"]

    # ``score`` belongs to the predicted label; the other class is taken as
    # its complement. Match on the prefix so that both a bare "positive"
    # label and a longer form such as "positive (stars 4 and 5)" are
    # recognised (an exact == 'positive' test would miss the long form).
    if str(result["label"]).lower().startswith("positive"):
        positive, negative = score, 1 - score
    else:
        negative, positive = score, 1 - score

    return {
        negative_key: f"{negative:.6f}",
        positive_key: f"{positive:.6f}",
    }
def process_batch(file):
    """Classify every non-blank line of an uploaded text file.

    Args:
        file: The uploaded file, either an object exposing its path as
            ``.name`` or a plain path string. The file is read as UTF-8,
            one sentence per line.

    Returns:
        tuple: ``(df, csv_path)`` where ``df`` is a DataFrame with the
        columns ``"文本"``, ``"negative (stars 1, 2 and 3)"`` and
        ``"positive (stars 4 and 5)"``, and ``csv_path`` is the path of a
        CSV copy written into a fresh temporary directory.

    Raises:
        gr.Error: If no file was provided.
    """
    if not file:
        raise gr.Error("请先上传文件")

    negative_key = "negative (stars 1, 2 and 3)"
    positive_key = "positive (stars 4 and 5)"
    columns = ["文本", negative_key, positive_key]

    # Accept both an upload object carrying ``.name`` and a bare path string.
    path = getattr(file, "name", file)

    # "utf-8-sig" reads plain UTF-8 too, but also drops a leading BOM that
    # would otherwise stay attached to the first sentence.
    with open(path, "r", encoding="utf-8-sig") as f:
        texts = [line.strip() for line in f if line.strip()]

    results = []
    for text in texts:
        pred = classifier(text)[0]
        score = pred["score"]
        # ``score`` belongs to the predicted label; the other class is its
        # complement. Prefix matching covers both "positive" and a longer
        # label such as "positive (stars 4 and 5)".
        if str(pred["label"]).lower().startswith("positive"):
            positive, negative = score, 1 - score
        else:
            negative, positive = score, 1 - score
        results.append({
            "文本": text,
            negative_key: f"{negative:.6f}",
            positive_key: f"{positive:.6f}",
        })

    # Passing the columns explicitly keeps the header row even when the file
    # contained no usable lines.
    df = pd.DataFrame(results, columns=columns)

    # Write the CSV into its own temporary directory. The directory is not
    # removed here because the returned path is handed back to the caller.
    temp_dir = tempfile.mkdtemp()
    csv_path = os.path.join(temp_dir, "情感分析结果.csv")
    df.to_csv(csv_path, index=False, encoding="utf-8-sig")
    return df, csv_path
# =============== UI layout ===============
APP_TITLE = "基于RoBERTa模型的中文情感分析系统"

# Column names shared by the result tables in this section.
_NEGATIVE_COLUMN = "negative (stars 1, 2 and 3)"
_POSITIVE_COLUMN = "positive (stars 4 and 5)"

# Page-level CSS. The last rule originally targeted "hi", which is not an
# HTML element; "h1" is used so the centering rule can apply to headings.
CUSTOM_CSS = """
.gradio-container {max-width: 900px;
margin:0 auto !important;
padding:20px;
}
.df-table {width: 100%;
}
h1, .markdown-text{
text-align:center !important;
}
"""

# Introductory text rendered at the top of the page.
INTRO_MARKDOWN = """
# 基于RoBERTa模型的中文情感分析系统
本项目使用RoBERTa中文情感分类模型,实现积极与消极情绪的二分类分析
**标签说明**:
positive:积极情绪(相当于评分4星和5星)
negative:消极情绪(相当于评分1星、2星和3星)
"""

with gr.Blocks(title=APP_TITLE, css=CUSTOM_CSS) as demo:
    # Title and description
    gr.Markdown(INTRO_MARKDOWN)

    # ===== Single-sentence tab =====
    with gr.Tab("单句情感分析"):
        gr.Markdown("### 请输入一句话进行情感分析")
        text_input = gr.Textbox(label="", placeholder="请输入要分析的文本...", lines=3)
        analyze_btn = gr.Button("分析", variant="primary")
        # Result table
        gr.Markdown("**分析结果**")
        single_result = gr.Dataframe(
            headers=[_NEGATIVE_COLUMN, _POSITIVE_COLUMN],
            datatype=["str", "str"],
            elem_classes=["df-table"],
            interactive=False,
        )

    # ===== Batch tab =====
    with gr.Tab("批量情感分析"):
        gr.Markdown("### 上传.txt文件(每行一句)")
        file_input = gr.File(label="", file_types=[".txt"])
        process_btn = gr.Button("上传分析", variant="primary")
        # Result table
        gr.Markdown("**分析结果:**")
        batch_result = gr.Dataframe(
            headers=["文本", _NEGATIVE_COLUMN, _POSITIVE_COLUMN],
            elem_classes=["df-table"],
            interactive=False,
        )
        # Hidden until a batch run has produced a CSV to download.
        download_btn = gr.DownloadButton("下载分析结果(CSV)", visible=False)
        # Holds the path of the most recently generated CSV.
        temp_file = gr.State()

    # ===== Event wiring =====
    # Single-sentence analysis (the original note marks this handler as the
    # key change that fixed a hang).
    @analyze_btn.click(inputs=text_input, outputs=single_result)
    def handle_single_analysis(text):
        """Return a one-row table with the scores for ``text``.

        Blank or missing input produces a row of "N/A" placeholders instead
        of calling the model.
        """
        if not text or not text.strip():
            return pd.DataFrame([{
                _NEGATIVE_COLUMN: "N/A",
                _POSITIVE_COLUMN: "N/A",
            }])
        return pd.DataFrame([predict_single(text)])

    # Batch analysis
    @process_btn.click(inputs=file_input, outputs=[batch_result, temp_file, download_btn])
    def handle_batch_analysis(file):
        """Run the batch job and reveal the download button with its file."""
        df, csv_path = process_batch(file)
        # Attach the CSV to the button right away. Previously the file was
        # only assigned inside the button's own click handler, so the first
        # click had nothing to download. With the value set here no separate
        # click handler is needed.
        return df, csv_path, gr.update(value=csv_path, visible=True)

    # Footer (centered)
    gr.Markdown("---")
    gr.Markdown("燕山大学文法学院 姜,联系方式: myhopejob@163.com")

# Launch the app
demo.launch(share=True)