JemeinAI commited on
Commit
401d4ac
·
verified ·
1 Parent(s): 471cb28

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +93 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import BertTokenizer, BertForQuestionAnswering
3
+ import torch
4
+ import fitz # PyMuPDF for PDF
5
+ import docx # python-docx for Word files
6
+
7
+ # Load model and tokenizer
8
+ model_name = "cgt/Roberta-wwm-ext-large-qa"
9
+ tokenizer = BertTokenizer.from_pretrained(model_name)
10
+ model = BertForQuestionAnswering.from_pretrained(model_name)
11
+
12
+ # Token limit for BERT (typically 512)
13
+ MAX_TOKENS = 512
14
+
15
+ # Extract text from files
16
+ def extract_text_from_file(file):
17
+ if file.name.endswith(".txt"):
18
+ return file.read().decode("utf-8")
19
+ elif file.name.endswith(".pdf"):
20
+ text = ""
21
+ doc = fitz.open(stream=file.read(), filetype="pdf")
22
+ for page in doc:
23
+ text += page.get_text()
24
+ return text
25
+ elif file.name.endswith(".docx"):
26
+ doc = docx.Document(file)
27
+ return "\n".join([para.text for para in doc.paragraphs])
28
+ else:
29
+ return "❌ 不支持的文件格式"
30
+
31
+ # Chunk large context
32
+ def chunk_text(text, max_length=MAX_TOKENS):
33
+ tokens = tokenizer.tokenize(text)
34
+ chunks = []
35
+ for i in range(0, len(tokens), max_length - 50): # leave room for question
36
+ chunk = tokens[i:i + max_length - 50]
37
+ chunks.append(tokenizer.convert_tokens_to_string(chunk))
38
+ return chunks
39
+
40
+ # QA function
41
+ def answer_question(context, question, file):
42
+ try:
43
+ if file:
44
+ context = extract_text_from_file(file)
45
+
46
+ if not context or not question:
47
+ return "⚠️ 请提供上下文和问题。"
48
+
49
+ best_answer = ""
50
+ best_score = -float("inf")
51
+
52
+ chunks = chunk_text(context)
53
+
54
+ for chunk in chunks:
55
+ inputs = tokenizer.encode_plus(question, chunk, return_tensors="pt", truncation=True)
56
+ input_ids = inputs["input_ids"].tolist()[0]
57
+
58
+ with torch.no_grad():
59
+ outputs = model(**inputs)
60
+
61
+ start_idx = torch.argmax(outputs.start_logits)
62
+ end_idx = torch.argmax(outputs.end_logits) + 1
63
+
64
+ answer = tokenizer.convert_tokens_to_string(
65
+ tokenizer.convert_ids_to_tokens(input_ids[start_idx:end_idx])
66
+ )
67
+
68
+ score = outputs.start_logits[0][start_idx] + outputs.end_logits[0][end_idx - 1]
69
+ if score > best_score and answer.strip():
70
+ best_answer = answer.strip()
71
+ best_score = score
72
+
73
+ return best_answer if best_answer else "🤔 没能从上下文中找到明确答案。"
74
+
75
+ except Exception as e:
76
+ return f"❌ 错误:{str(e)}"
77
+
78
+ # Gradio Interface
79
+ with gr.Blocks(title="中文BERT问答系统(含文档上传)") as demo:
80
+ gr.Markdown("## 📘 中文BERT问答系统\n支持 `.txt`、`.pdf`、`.docx` 文档上传或手动输入上下文。")
81
+
82
+ with gr.Row():
83
+ context_input = gr.Textbox(label="📝 上下文(可选)", placeholder="或上传文件", lines=6)
84
+ file_input = gr.File(label="📂 上传文档", file_types=[".txt", ".pdf", ".docx"])
85
+
86
+ question_input = gr.Textbox(label="❓ 问题", placeholder="请输入问题", lines=2)
87
+ answer_output = gr.Textbox(label="📌 答案", lines=3)
88
+ submit_btn = gr.Button("提交")
89
+
90
+ submit_btn.click(fn=answer_question, inputs=[context_input, question_input, file_input], outputs=answer_output)
91
+
92
+ # 启动应用
93
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ PyMuPDF
5
+ python-docx