|
|
|
|
|
import gradio as gr |
|
|
import spacy |
|
|
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline |
|
|
import torch |
|
|
import time |
|
|
import re |
|
|
|
|
|
# spaCy English pipeline ("md" = medium); the named entities it finds are
# used as answer spans for question generation.
nlp = spacy.load("en_core_web_md")

# Dropdown label -> Hugging Face Hub checkpoint id, in display order.
# "prophetnet tieu chuan" is Vietnamese for "standard ProphetNet" (the public
# Microsoft SQuAD question-generation checkpoint). The "prophetnet2" entry is
# presumably a SQuAD 1.1 fine-tune, judging by its repo name -- confirm.
MODEL_PATHS = dict(
    [
        ("prophetnet2", "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break"),
        ("prophetnet tieu chuan", "microsoft/prophetnet-large-uncased-squad-qg"),
    ]
)
|
|
|
|
|
def load_pipeline(model_path):
    """Build a ``text2text-generation`` pipeline around a ProphetNet checkpoint.

    Args:
        model_path: Hub id (or local path) passed to ``from_pretrained`` for
            both the tokenizer and the model.

    Returns:
        A ``transformers`` pipeline configured with ``max_length=256`` and one
        returned sequence. It runs on device 0 when CUDA is available, and on
        the CPU (device ``-1``) otherwise.
    """
    hf_tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
    hf_model = ProphetNetForConditionalGeneration.from_pretrained(model_path)

    use_gpu = torch.cuda.is_available()
    target_device = 0 if use_gpu else -1

    return pipeline(
        "text2text-generation",
        model=hf_model,
        tokenizer=hf_tokenizer,
        max_length=256,
        num_return_sequences=1,
        device=target_device,
    )
|
|
|
|
|
# Lazily filled cache: model display name -> loaded pipeline.
pipeline_cache = {}


def get_pipeline(model_name):
    """Return the pipeline for *model_name*, loading it on first use.

    Raises:
        KeyError: if *model_name* is not a key of ``MODEL_PATHS`` (checked
            before the cache is consulted).
    """
    model_path = MODEL_PATHS[model_name]
    cached = pipeline_cache.get(model_name)
    if cached is None:
        cached = load_pipeline(model_path)
        pipeline_cache[model_name] = cached
    return cached
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def smart_capitalize(text):
    """Tidy a short sentence for display.

    Trims surrounding whitespace and upper-cases only the first character,
    leaving the rest unchanged. A ``.`` is appended unless the text already
    ends with ``.``, ``?`` or ``!``. Empty or whitespace-only input yields
    ``""``.
    """
    cleaned = text.strip()
    if cleaned == "":
        return cleaned

    cleaned = cleaned[:1].upper() + cleaned[1:]
    # After strip() there is no trailing newline, so endswith matches the
    # same strings as the regex [.?!]$ would.
    return cleaned if cleaned.endswith((".", "?", "!")) else cleaned + "."
|
|
|
|
|
def generate_question(context, answer, model_name):
    """Generate one question whose answer is *answer*, given *context*.

    Args:
        context: Passage the question should be about.
        answer: Answer span the question should target. In this app it is a
            spaCy entity string.
        model_name: Key of ``MODEL_PATHS`` selecting the checkpoint.

    Returns:
        The decoded question, trimmed, with its first letter capitalised and
        ending in ``.``, ``?`` or ``!`` (via ``smart_capitalize``). On any
        failure, such as an unknown model name, a load or tokenisation error,
        or a generation error, it returns the message
        ``"Lỗi khi sinh câu hỏi: <error>"`` instead of raising. The UI can then
        show it in place of the question.
    """
    try:
        # The model lookup/load and tokenisation are inside the try as well,
        # so the caller always receives a displayable string.
        pipe = get_pipeline(model_name)
        tokenizer = pipe.tokenizer
        device = pipe.model.device

        # NOTE(review): the prompt layout must match what each checkpoint was
        # fine-tuned on. The public microsoft/prophetnet-large-uncased-squad-qg
        # model card uses "answer [SEP] context" -- confirm per MODEL_PATHS entry.
        prompt = f"context: {context} answer: {answer}"
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        output = pipe.model.generate(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
            max_length=64,
            num_return_sequences=1,
            num_beams=4,
        )

        # smart_capitalize strips, capitalises and already guarantees terminal
        # punctuation, so no further post-processing is needed.
        result = smart_capitalize(tokenizer.decode(output[0], skip_special_tokens=True))
        print(f"Generated question: {result}")
        return result
    except Exception as e:  # UI boundary: show the error text to the user.
        return f"Lỗi khi sinh câu hỏi: {e}"
|
|
|
|
|
|
|
|
|
|
|
def generate_qa_list(context, num_questions, model_choice):
    """Build question/answer markdown snippets from the named entities in *context*.

    Each distinct spaCy entity of at most 10 whitespace-separated words is
    used as an answer, and ``generate_question`` produces the matching
    question. Entities are taken in order of first appearance.

    Args:
        context: Input passage.
        num_questions: Requested number of questions. This is a slider value
            that may arrive as a float, so it is coerced to ``int``.
        model_choice: Key of ``MODEL_PATHS`` forwarded to ``generate_question``.

    Returns:
        ``(gr.update(visible=True), qa_list)``. ``qa_list`` holds one markdown
        string per question, with the answer hidden inside a ``<details>``
        element. When no entity is found it holds a single error message.
    """
    import html  # local import keeps this block self-contained

    doc = nlp(context)
    # dict.fromkeys de-duplicates while keeping document order, so the same
    # context always yields the same answers. list(set(...)) ordered them by
    # string hash, which changes between interpreter runs.
    entities = [
        text
        for text in dict.fromkeys(ent.text for ent in doc.ents)
        if len(text.strip().split()) <= 10
    ]

    if not entities:
        return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."]

    # Clamp at 0: a negative slice bound would select almost every entity,
    # unlike the original range(), which yielded nothing.
    count = max(0, min(int(num_questions), len(entities)))

    qa_list = []
    for entity in entities[:count]:
        question = generate_question(context, entity, model_choice)
        answer = smart_capitalize(entity)
        # The text comes from user input and is embedded in raw HTML/markdown.
        # Escape it so a stray "<" or "&" cannot corrupt the <details> markup.
        qa_list.append(
            f"**{html.escape(question, quote=False)}**\n"
            f"<details><summary>Hiện câu trả lời</summary><p>{html.escape(answer, quote=False)}</p></details>"
        )

    return gr.update(visible=True), qa_list
|
|
|
|
|
|
|
|
def analyze_context(context):
    """Inspect *context* and show or hide the generation controls.

    Returns four ``gr.update`` objects, in this order: status message,
    question-count slider, generate button, elapsed-time markdown. With no
    usable entity, the status message explains why and the other three are
    hidden. Otherwise the slider is shown, with its maximum set to the number
    of questions that can actually be generated and displayed.
    """
    # Number of result slots the UI renders (5 gr.Markdown components, which
    # is also the slider's initial maximum). Without this cap, extra questions
    # were generated but never shown.
    max_slots = 5

    doc = nlp(context)
    # Same filtering as generate_qa_list: distinct entities of at most 10 words.
    entities = [
        text
        for text in dict.fromkeys(ent.text for ent in doc.ents)
        if len(text.strip().split()) <= 10
    ]
    entity_count = len(entities)

    if entity_count == 0:
        return (
            gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    max_questions = min(entity_count, max_slots)
    return (
        gr.update(visible=False),
        gr.update(
            visible=True,
            maximum=max_questions,
            value=min(3, max_questions),
            label=f"Số câu hỏi (Tối đa: {max_questions})",
        ),
        gr.update(visible=True),
        gr.update(visible=True),
    )
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng ProphetNet + spaCy NER")

    with gr.Row():
        with gr.Column(scale=4):
            context_input = gr.Textbox(label="Nhập Context", lines=15, placeholder="Nhập đoạn văn bản...")
            elapsed_time_md = gr.Markdown(visible=False)

        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="Chọn mô hình",
                choices=list(MODEL_PATHS),
                # Default to a key that exists. The previous default,
                # "prophetnet1", is not in MODEL_PATHS, so generating without
                # touching the dropdown raised KeyError in get_pipeline.
                value=next(iter(MODEL_PATHS)),
            )
            num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False)
            generate_btn = gr.Button("Sinh câu hỏi", visible=False)

    status_message = gr.Markdown(visible=False)

    # Fixed pool of result slots; questions beyond len(result_md_list) cannot be shown.
    with gr.Column(visible=False) as output_container:
        result_md_list = [gr.Markdown(visible=False) for _ in range(5)]

    def run_generation(context, num_questions, selected_model):
        """Generate the Q/A list and map it onto the UI outputs.

        Returns updates in the order of the click handler's outputs:
        status message (hidden), output container, elapsed-time markdown,
        then one update per result slot.
        """
        start_time = time.perf_counter()  # monotonic clock for durations

        # Never generate more questions than there are slots to display them.
        num_questions = min(int(num_questions), len(result_md_list))
        visible_container, qa_list = generate_qa_list(context, num_questions, selected_model)

        slot_updates = [
            gr.update(value=qa_list[i], visible=True) if i < len(qa_list) else gr.update(visible=False)
            for i in range(len(result_md_list))
        ]

        elapsed = time.perf_counter() - start_time
        elapsed_md = gr.update(value=f"⏱️ Thời gian xử lý: {elapsed:.2f} giây", visible=True)

        return [gr.update(visible=False), visible_container, elapsed_md] + slot_updates

    context_input.change(
        fn=analyze_context,
        inputs=[context_input],
        outputs=[status_message, num_input, generate_btn, elapsed_time_md],
    )

    def show_processing():
        """Show the "processing" status while generation runs."""
        return gr.update(value="⏳ Đang xử lý...", visible=True)

    generate_btn.click(
        fn=show_processing,
        inputs=[],
        outputs=[status_message],
    ).then(
        fn=run_generation,
        inputs=[context_input, num_input, model_choice],
        outputs=[status_message, output_container, elapsed_time_md] + result_md_list,
    )

demo.launch()
|
|
|
|
|
|