# bert_Chinese / app.py
# JemeinAI's picture
# Update app.py
# 7ffc1ff verified
import gradio as gr
from transformers import BertTokenizer, BertForQuestionAnswering
import torch
# Checkpoint used for extractive question answering on Chinese text.
model_name = "cgt/Roberta-wwm-ext-large-qa"

# Load the tokenizer first, then the QA model, both from the same checkpoint.
tokenizer, model = (
    loader.from_pretrained(model_name)
    for loader in (BertTokenizer, BertForQuestionAnswering)
)
# QA function
def _is_cjk(ch):
    """Return True if *ch* is a CJK ideograph, CJK punctuation, or a fullwidth form."""
    code = ord(ch)
    return (
        0x3000 <= code <= 0x303F  # CJK symbols and punctuation
        or 0x3400 <= code <= 0x4DBF  # CJK Unified Ideographs Extension A
        or 0x4E00 <= code <= 0x9FFF  # CJK Unified Ideographs
        or 0xFF00 <= code <= 0xFFEF  # halfwidth and fullwidth forms
    )


def _join_wordpieces(tokens):
    """Join WordPiece tokens into text, inserting spaces only between non-CJK tokens."""
    pieces = []
    for token in tokens:
        if not token:
            continue
        if token.startswith("##"):
            # Continuation piece: glue it directly onto the previous piece.
            token = token[2:]
            glue = ""
        elif pieces and not (_is_cjk(pieces[-1][-1]) or _is_cjk(token[0])):
            # Two adjacent non-CJK tokens (e.g. Latin words) keep a separating space.
            glue = " "
        else:
            glue = ""
        if token:
            pieces.append(glue + token)
    return "".join(pieces)


def answer_question(context, question):
    """Extract the answer to *question* from *context* with the module-level QA model.

    Args:
        context: Reference passage to search for the answer.
        question: Question to answer.

    Returns:
        The extracted answer text, or a user-facing message (in Chinese) when
        an input is empty, no answer span is found, or an error occurs.
        Exceptions are never propagated; they are rendered into the returned
        string so the UI can display them.
    """
    # Reject missing or whitespace-only inputs before running the model.
    if not (context and context.strip()) or not (question and question.strip()):
        return "⚠️ 请输入上下文和问题。"
    try:
        inputs = tokenizer.encode_plus(
            question,
            context,
            return_tensors="pt",
            # Only the context (second sequence) is shortened, so the question
            # is never cut; the limit is the model's position-embedding size.
            truncation="only_second",
            max_length=model.config.max_position_embeddings,
        )
        input_ids = inputs["input_ids"][0].tolist()
        with torch.no_grad():
            outputs = model(**inputs)
        # Plain ints for slicing; the end index is made exclusive.
        start_idx = int(torch.argmax(outputs.start_logits))
        end_idx = int(torch.argmax(outputs.end_logits)) + 1
        # Drop structural tokens so "[CLS]" / "[SEP]" never appear in the answer.
        structural = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
        span_tokens = [
            tok
            for tok in tokenizer.convert_ids_to_tokens(input_ids[start_idx:end_idx])
            if tok not in structural
        ]
        # convert_tokens_to_string would put a space between every Chinese
        # character, so the pieces are joined with CJK-aware spacing instead.
        answer = _join_wordpieces(span_tokens).strip()
        return answer or "⚠️ 抱歉,我无法从上下文中找到答案。"
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app.
        return f"❌ 错误:{e}"
# Build the Gradio interface.
with gr.Blocks(title="中文BERT问答系统") as demo:
    gr.Markdown("## 请在下方分别输入上下文和问题。")

    # NOTE(review): the original indentation was lost; the two input boxes are
    # assumed to share this row, with the answer box and button below it —
    # confirm against the deployed layout.
    with gr.Row():
        context_input = gr.Textbox(
            label="📝 上下文(Context)",
            placeholder="请输入参考内容……",
            lines=6,
        )
        question_input = gr.Textbox(
            label="❓ 问题(Question)",
            placeholder="请输入你的问题……",
            lines=2,
        )

    answer_output = gr.Textbox(label="📌 答案", lines=2)
    submit_btn = gr.Button("提交")

    # Clicking the button sends (context, question) to answer_question and
    # writes its return value into the answer box.
    submit_btn.click(
        fn=answer_question,
        inputs=[context_input, question_input],
        outputs=answer_output,
    )

# Start the app.
demo.launch()