File size: 1,669 Bytes
401d4ac
 
 
 
3df7c63
401d4ac
 
 
 
3df7c63
 
401d4ac
3df7c63
 
401d4ac
3df7c63
 
401d4ac
3df7c63
 
401d4ac
3df7c63
 
 
 
 
401d4ac
 
 
 
3df7c63
 
7ffc1ff
3df7c63
 
 
 
 
 
 
 
 
 
 
 
 
401d4ac
59e48ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
from transformers import BertTokenizer, BertForQuestionAnswering
import torch

# Load Chinese BERT-Large model fine-tuned for QA
# NOTE(review): the checkpoint name says "Roberta-wwm-ext", but it is loaded with
# BertTokenizer/BertForQuestionAnswering — presumably correct for the Chinese
# wwm family (which ships BERT-style vocab/weights); confirm against the model card.
# This runs at import time and downloads weights on first use.
model_name = "cgt/Roberta-wwm-ext-large-qa"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)

# QA function
def answer_question(context, question):
    """Extract an answer span for `question` from `context`.

    Runs the QA model, takes the argmax start/end logits, and decodes the
    corresponding token span back to a string.

    Args:
        context: Reference passage (str) the answer is drawn from.
        question: Question (str) to answer.

    Returns:
        The decoded answer string, a Chinese "no answer found" notice when
        no valid span is predicted, or an error message on failure.
    """
    try:
        # truncation=True keeps question+context within the model's max length.
        inputs = tokenizer.encode_plus(question, context, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].tolist()[0]

        with torch.no_grad():
            outputs = model(**inputs)

        start_idx = torch.argmax(outputs.start_logits)
        end_idx = torch.argmax(outputs.end_logits) + 1

        # Fix: the start/end heads are independent, so the model can predict an
        # end before the start; the old code silently sliced an empty/garbage
        # span. Treat an inverted span as "no answer".
        if end_idx <= start_idx:
            return "⚠️ 抱歉,我无法从上下文中找到答案。"

        tokens = tokenizer.convert_ids_to_tokens(input_ids[start_idx:end_idx])
        # Fix: drop special tokens ([CLS]/[SEP]/[PAD]) which previously leaked
        # into the user-visible answer when the span touched them.
        tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
        answer = tokenizer.convert_tokens_to_string(tokens)

        return answer.strip() if answer.strip() else "⚠️ 抱歉,我无法从上下文中找到答案。"

    except Exception as e:
        # UI boundary: surface the error to the Gradio textbox rather than crash.
        return f"❌ 错误:{str(e)}"

# Build the Gradio interface: two input textboxes (context, question),
# one output textbox, and a submit button wired to answer_question.
with gr.Blocks(title="中文BERT问答系统") as demo:
    gr.Markdown("## 请在下方分别输入上下文和问题。")

    with gr.Row():
        context_input = gr.Textbox(label="📝 上下文(Context)", placeholder="请输入参考内容……", lines=6)
        question_input = gr.Textbox(label="❓ 问题(Question)", placeholder="请输入你的问题……", lines=2)

    answer_output = gr.Textbox(label="📌 答案", lines=2)

    submit_btn = gr.Button("提交")

    # Clicking the button calls answer_question(context, question) and
    # writes its return value into the answer textbox.
    submit_btn.click(fn=answer_question, inputs=[context_input, question_input], outputs=answer_output)

# Launch the app (runs at import time; blocks until the server stops).
demo.launch()