# bert-app-v1 / app.py
# Hugging Face Space file header — uploaded by yzh621, commit 7a5ec60 (verified): "Update app.py"
import os
import torch
import gradio as gr
from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification
# 1. Basic environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence the tokenizers fork-parallelism warning
os.environ["TF_USE_LEGACY_KERAS"] = "1"  # keep transformers on legacy Keras if TensorFlow is installed
print("正在初始化 AI 模型...")
# ----------------------------------------------------------------
# 1. Fill-mask (cloze) feature
# ----------------------------------------------------------------
try:
    # CPU inference (device=-1); bert-base-chinese predicts the [MASK] token.
    unmasker = pipeline('fill-mask', model='bert-base-chinese', device=-1)
except Exception as e:
    # Best-effort load: keep the app alive and let the UI report the failure.
    print(f"Fill-Mask 加载警告: {e}")
    unmasker = None
def fill_mask_ai(text):
    """Predict candidate tokens for the ``[MASK]`` placeholder in *text*.

    Uses the module-level ``unmasker`` fill-mask pipeline (bert-base-chinese).

    Args:
        text: Input sentence that must contain the literal ``[MASK]`` marker.

    Returns:
        A human-readable ranking of candidate tokens with confidence
        percentages, or a Chinese error message when the model failed to
        load, the marker is missing, or inference raises.
    """
    if unmasker is None:
        return "模型加载出错"
    if "[MASK]" not in text:
        return "⚠️ 错误:请在句子中包含 [MASK] 符号"
    try:
        results = unmasker(text)
        # Build the ranking with a comprehension + join instead of
        # quadratic `+=` string concatenation in a loop.
        lines = [
            f"第 {rank} 名: 【{res['token_str']}】 (置信度: {res['score'] * 100:.1f}%)\n"
            for rank, res in enumerate(results, start=1)
        ]
        return "".join(lines)
    except Exception as e:
        return f"运行出错: {e}"
# ----------------------------------------------------------------
# 2. Reading comprehension (model loaded manually to avoid pipeline errors)
# ----------------------------------------------------------------
qa_model_name = "uer/roberta-base-chinese-extractive-qa"
try:
    print("正在手动加载 QA 模型...")
    # Load model + tokenizer directly instead of via pipeline() —
    # per the original author's note this avoids a pipeline load error.
    qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
    qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
    print("QA 模型加载成功!(原生模式)")
except Exception as e:
    # Degrade gracefully: the QA tab will report the failure at call time.
    print(f"QA 模型加载失败: {e}")
    qa_model = None
    qa_tokenizer = None
def reading_comprehension(context, question):
    """Extractive QA: find the answer span for *question* inside *context*.

    Args:
        context: The passage to search.
        question: The question to answer.

    Returns:
        A formatted answer string with a model score, or a Chinese error /
        fallback message when inputs are missing, the model is unavailable,
        or no span is found.
    """
    if qa_model is None:
        return "QA 模型未能成功加载"
    if not context or not question:
        return "请填写完整的文章和问题。"
    try:
        encoded = qa_tokenizer(question, context, return_tensors="pt")
        with torch.no_grad():
            logits = qa_model(**encoded)
        # Most likely start/end positions, each taken independently.
        span_start = logits.start_logits.argmax()
        span_end = logits.end_logits.argmax()
        span_ids = encoded.input_ids[0, span_start : span_end + 1]
        answer = qa_tokenizer.decode(span_ids, skip_special_tokens=True)
        if not answer.strip():
            # An inverted/empty span decodes to nothing — report honestly.
            return "(未能找到明确答案)"
        # Crude confidence: mean of the best start and end logits.
        confidence = (logits.start_logits.max().item() + logits.end_logits.max().item()) / 2.0
        return f"🤔 AI 的回答:{answer}\n(模型得分: {confidence:.2f})"
    except Exception as e:
        return f"推理出错: {e}"
# ----------------------------------------------------------------
# 3. Sentiment analysis (label-mapping logic fixed per the author's note)
# ----------------------------------------------------------------
senti_model_name = "uer/roberta-base-finetuned-dianping-chinese"
try:
    print("正在手动加载情感分析模型...")
    senti_model = AutoModelForSequenceClassification.from_pretrained(senti_model_name)
    senti_tokenizer = AutoTokenizer.from_pretrained(senti_model_name)
    # Wrap the manually loaded model/tokenizer in a CPU pipeline (device=-1).
    sentiment_pipeline = pipeline('text-classification', model=senti_model, tokenizer=senti_tokenizer, device=-1)
except Exception as e:
    # Degrade gracefully: the sentiment tab will report the failure at call time.
    print(f"情感分析加载失败: {e}")
    sentiment_pipeline = None
def sentiment_analysis(text):
    """Classify *text* as positive / neutral / negative with a confidence.

    Maps the raw model label (e.g. "positive (5 stars)") onto a Chinese
    emoji label via keyword matching.

    Args:
        text: The sentence to classify.

    Returns:
        A two-line Chinese result string, or an error message when the
        pipeline is unavailable or inference raises.
    """
    if sentiment_pipeline is None:
        return "模型未能加载"
    try:
        top = sentiment_pipeline(text)[0]
        raw_label = str(top['label']).lower()  # lowercase so matching ignores case
        score = top['score'] * 100
        # Debug trace: log the raw label so mis-mappings can be diagnosed.
        print(f"输入: {text} | 原始标签: {raw_label} | 分数: {score}")
        # Strict keyword matching — the model typically emits
        # "positive (5 stars)" / "negative (1 star)" style labels.
        if any(marker in raw_label for marker in ('positive', '5 star', '4 star')):
            label = "😊 正面/积极"
        elif 'neutral' in raw_label:  # rare: some labels are neutral
            label = "😐 中性/平和"
        else:
            # Everything else (negative, 1-3 stars) counts as negative.
            label = "😡 负面/消极"
        return f"分析结果:{label}\n强度:{score:.1f}%"
    except Exception as e:
        return f"分析出错: {e}"
# ----------------------------------------------------------------
# 4. UI construction
# ----------------------------------------------------------------
# Force a serif CJK font family across all Gradio widgets.
custom_css = """
body, .gradio-container, .prose, input, button, textarea, span, label {
font-family: 'SimSun', 'STSong', 'Songti SC', serif !important;
}
"""
with gr.Blocks(title="BERT Playground", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# 🤖 BERT 语言模型AI助手")
    gr.Markdown("\n\n基于 Google BERT 及其变体模型构建的中文 AI 演示")
    # Tab 1: fill-mask (cloze) demo
    with gr.Tab("🧩 完形填空 (Fill Mask)"):
        gr.Markdown("输入一句话,用 `[MASK]` 代替你想让 AI 猜的词。")
        input_mask = gr.Textbox(label="输入句子", value="我要打王者[MASK]耀。")
        btn_mask = gr.Button("开始猜词", variant="primary")
        output_mask = gr.Textbox(label="AI 的猜测")
        btn_mask.click(fill_mask_ai, inputs=input_mask, outputs=output_mask)
    # Tab 2: extractive question answering demo
    with gr.Tab("📖 阅读理解 (Q&A)"):
        gr.Markdown("粘贴一段短文,然后问 AI 一个问题。")
        default_context = """A:“小明,你的牙齿真好看!”
B:“哦,那是假的!”
A:“啊?真的假的?”
B:“真的"""
        input_context = gr.Textbox(label="文章 (Context)", lines=5, value=default_context)
        input_question = gr.Textbox(label="你的问题 (Question)", value="小明的牙齿是真的还是假的?")
        btn_qa = gr.Button("寻找答案", variant="primary")
        output_qa = gr.Textbox(label="BERT 的对答")
        btn_qa.click(reading_comprehension, inputs=[input_context, input_question], outputs=output_qa)
    # Tab 3: sentiment analysis demo
    with gr.Tab("❤️ 情感分析 (Sentiment)"):
        gr.Markdown("输入一句话,AI 判断语气是积极还是消极。")
        input_senti = gr.Textbox(label="输入评价", value="我不想上早八。")
        btn_senti = gr.Button("分析情绪", variant="primary")
        output_senti = gr.Textbox(label="分析结果")
        btn_senti.click(sentiment_analysis, inputs=input_senti, outputs=output_senti)
if __name__ == "__main__":
    # ssr_mode=False disables Gradio's server-side rendering —
    # presumably a workaround for Spaces deployment issues; confirm if changed.
    demo.launch(ssr_mode=False)