File size: 6,582 Bytes
c750faa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import gradio as gr
import spacy
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
import torch
import time
import re
nlp = spacy.load("en_core_web_md")
MODEL_PATHS = {
"prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
"prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg"
}
def load_pipeline(model_path):
tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
pipeline_cache = {}
def get_pipeline(model_name):
model_path = MODEL_PATHS[model_name]
if model_name not in pipeline_cache:
pipeline_cache[model_name] = load_pipeline(model_path)
return pipeline_cache[model_name]
# Tự viết hàm capitalize thông minh
def smart_capitalize(text):
# Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
text = text.strip()
if not text:
return text
text = text[0].upper() + text[1:]
if not re.search(r'[.?!]$', text):
text += '.'
return text
def generate_question(context, answer, model_name):
pipe = get_pipeline(model_name)
tokenizer = pipe.tokenizer
prompt = f"context: {context} answer: {answer}"
# Cắt prompt nếu vượt quá giới hạn token
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
try:
output = pipe.model.generate(
input_ids=input_ids.to(pipe.model.device),
attention_mask=attention_mask.to(pipe.model.device),
max_length=64,
num_return_sequences=1,
num_beams=4
)
result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
result = smart_capitalize(result)
print(f"Generated question: {result}")
# Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?)
if not re.search(r'[.?!]$', result):
result += '.'
return result
except Exception as e:
return f"Lỗi khi sinh câu hỏi: {e}"
def generate_qa_list(context, num_questions, model_choice):
doc = nlp(context)
entities = list(set([ent.text for ent in doc.ents]))
entities = [e for e in entities if len(e.strip().split()) <= 10]
if not entities:
return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."]
count = min(num_questions, len(entities))
qa_list = []
for i in range(count):
answer = entities[i]
question = generate_question(context, answer, model_choice)
answer = smart_capitalize(entities[i])
qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
qa_list.append(qa)
return gr.update(visible=True), qa_list
# Tách phần phân tích context và cập nhật slider
def analyze_context(context):
doc = nlp(context)
entities = list(set([ent.text for ent in doc.ents]))
entities = [e for e in entities if len(e.strip().split()) <= 10]
entity_count = len(entities)
if entity_count == 0:
return (
gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False)
)
else:
return (
gr.update(visible=False),
gr.update(visible=True, maximum=entity_count, value=min(3, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
gr.update(visible=True),
gr.update(visible=True)
)
with gr.Blocks() as demo:
gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng ProphetNet + spaCy NER")
with gr.Row():
with gr.Column(scale=4):
context_input = gr.Textbox(label="Nhập Context", lines=15, placeholder="Nhập đoạn văn bản...")
elapsed_time_md = gr.Markdown(visible=False)
with gr.Column(scale=1):
model_choice = gr.Dropdown(
label="Chọn mô hình",
choices=list(MODEL_PATHS.keys()),
value="prophetnet1"
)
num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False)
generate_btn = gr.Button("Sinh câu hỏi", visible=False)
# Thông báo đang xử lý hoặc không tìm thấy
status_message = gr.Markdown(visible=False)
# Kết quả hiển thị tại đây
with gr.Column(visible=False) as output_container:
result_md_list = [gr.Markdown(visible=False) for _ in range(5)]
# Xử lý khi bấm nút sinh câu hỏi
def run_generation(context, num_questions, model_choice):
start_time = time.time()
visible_container, qa_list = generate_qa_list(context, num_questions, model_choice)
status_hide = gr.update(visible=False)
updates = []
for i in range(5):
if i < len(qa_list):
updates.append(gr.update(value=qa_list[i], visible=True))
else:
updates.append(gr.update(visible=False))
elapsed = time.time() - start_time
elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
elapsed_md = gr.update(value=elapsed_msg, visible=True)
return [status_hide, visible_container, elapsed_md] + updates
# Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
context_input.change(
fn=analyze_context,
inputs=[context_input],
outputs=[status_message, num_input, generate_btn, elapsed_time_md]
)
def show_processing():
return gr.update(value="⏳ Đang xử lý...", visible=True)
generate_btn.click(
fn=show_processing,
inputs=[],
outputs=[status_message]
).then(
fn=run_generation,
inputs=[context_input, num_input, model_choice],
outputs=[status_message, output_container, elapsed_time_md] + result_md_list
)
demo.launch()
# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py |