import gradio as gr
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, BartTokenizer, BartForConditionalGeneration
import torch
import time
import re

# Download the spaCy English model on first run, then load it; the named
# entities it extracts become the answers that questions are generated for.
if not spacy.util.is_package("en_core_web_md"):
    print("Downloading spaCy model 'en_core_web_md'...")
    spacy.cli.download("en_core_web_md")
nlp = spacy.load("en_core_web_md")

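# Hugging Face Hub repositories for the three fine-tuned question-generation models.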
MODEL_PATHS = {
    "prophetnet-large-uncased-finetuned": "ManB2207540/prophetnet_large_uncased_qg_squad_2epoch_finetuned",
    "bart-finetuned": "mghan3624/bart_qg_finetune_squad",
    "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
}

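# Loader helpers: each builds a text2text-generation pipeline with the
# model-specific tokenizer/model classes, on GPU when available.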
def load_t5_pipeline(model_path):
    try:
        tokenizer = T5Tokenizer.from_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load T5 pipeline for {model_path}: {e}")
        return None

def load_prophetnet_pipeline(model_path):
    try:
        tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
        model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load ProphetNet pipeline for {model_path}: {e}")
        return None

def load_bart_pipeline(model_path):
    try:
        config = AutoConfig.from_pretrained(model_path)
        # Some BART checkpoints ship with early_stopping unset, which breaks
        # beam-search generation; default it to False in that case.
        if getattr(config, "early_stopping", None) is None:
            config.early_stopping = False
        tokenizer = BartTokenizer.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load BART pipeline for {model_path}: {e}")
        return None

def load_pipeline(model_path):
    # Generic fallback loader using the Auto* classes.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load pipeline for {model_path}: {e}")
        return None

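# Cache loaded pipelines so each model is downloaded and initialized only once.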
pipeline_cache = {}

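# Return the cached pipeline for model_name, loading it on first use.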
def get_pipeline(model_name):
    model_path = MODEL_PATHS[model_name]
    if model_name not in pipeline_cache:
        if model_name == "t5-small-finetuned":
            pipeline_cache[model_name] = load_t5_pipeline(model_path)
        elif model_name == "prophetnet-large-uncased-finetuned":
            pipeline_cache[model_name] = load_prophetnet_pipeline(model_path)
        elif model_name == "bart-finetuned":
            pipeline_cache[model_name] = load_bart_pipeline(model_path)
        else:
            pipeline_cache[model_name] = load_pipeline(model_path)
    return pipeline_cache[model_name]

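# Capitalize the first letter and ensure the text ends with terminal punctuation.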
def smart_capitalize(text):
    text = text.strip()
    if not text:
        return text
    text = text[0].upper() + text[1:]
    if not re.search(r'[.?!]$', text):
        text += '.'
    return text

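# Build the model-specific prompt, reject inputs longer than 512 tokens, and
# generate a single question with beam search.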
def generate_question(context, answer, model_name):
    pipe = get_pipeline(model_name)
    tokenizer = pipe.tokenizer

    # T5 was fine-tuned with an explicit task prefix; the other models all
    # expect the bare "context: ... answer: ..." format.
    if model_name == "t5-small-finetuned":
        prompt = f"generate question: context: {context} answer: {answer}"
    else:
        prompt = f"context: {context} answer: {answer}"
    print(f"Prompt: {prompt}")

    # Tokenize without truncation first to check the length against the
    # 512-token limit, then re-tokenize with truncation for generation.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=False)
    if encoded["input_ids"].size(1) > 512:
        return "❌ The text is too long. Please enter a shorter passage."

    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    try:
        output = pipe.model.generate(
            input_ids=input_ids.to(pipe.model.device),
            attention_mask=attention_mask.to(pipe.model.device),
            max_length=64,
            num_return_sequences=1,
            num_beams=4
        )
        result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
        # smart_capitalize already guarantees terminal punctuation.
        result = smart_capitalize(result)
        print(f"Generated question: {result}")
        return result
    except Exception as e:
        return f"Error while generating the question: {e}"

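# Extract named entities as candidate answers and generate one question per
# entity, formatted as Markdown with a collapsible answer.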
def generate_qa_list(context, num_questions, model_choice):
    doc = nlp(context)
    # Unique entities only, capped at 10 words so answers stay short.
    entities = list(set(ent.text for ent in doc.ents))
    entities = [e for e in entities if len(e.strip().split()) <= 10]

    if not entities:
        return gr.update(visible=True, value="❌ No entities were found to generate questions from."), []

    count = min(num_questions, len(entities))
    qa_list = []

    for i in range(count):
        answer = entities[i]
        question = generate_question(context, answer, model_choice)

        # Propagate length and generation errors straight to the status box.
        if question.startswith("❌") or question.startswith("Error"):
            return gr.update(visible=True, value=question), []

        answer = smart_capitalize(entities[i])
        qa = f"**{question}**\n<details><summary>Show answer</summary><p>{answer}</p></details>"
        qa_list.append(qa)

    return gr.update(visible=False), qa_list

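# Re-run entity extraction whenever the input text changes, and show or hide
# the question-count slider and generate button accordingly.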
def analyze_context(context, num_questions):
    doc = nlp(context)
    entities = list(set(ent.text for ent in doc.ents))
    entities = [e for e in entities if len(e.strip().split()) <= 10]
    entity_count = len(entities)

    if entity_count == 0:
        return (
            gr.update(visible=True, value="❌ No entities were found to generate questions from."),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False)
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True, maximum=entity_count, value=min(num_questions, entity_count), label=f"Number of questions (max: {entity_count})"),
            gr.update(visible=True),
            gr.update(visible=True)
        )

with gr.Blocks() as demo:
    gr.Markdown("## 📘 Question Generation System")

    with gr.Row():
        with gr.Column(scale=4):
            context_input = gr.Textbox(label="Input text", lines=15, placeholder="Enter a passage of text...")
            elapsed_time_md = gr.Markdown(visible=False)
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="Select model",
                choices=list(MODEL_PATHS.keys()),
                value="t5-small-finetuned"
            )
            num_input = gr.Slider(label="Number of questions", minimum=1, maximum=20, value=3, step=1, visible=False)
            generate_btn = gr.Button("🚀 Generate questions", visible=False)

    status_message = gr.Markdown(visible=False)

    # Up to 20 question slots are pre-created; unused ones stay hidden.
    with gr.Column(visible=False) as output_container:
        result_md_list = [gr.Markdown(visible=False) for _ in range(20)]

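    # Time the full generation run and fan the results out to the 20
    # pre-created Markdown slots.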
    def run_generation(context, num_questions, model_choice):
        start_time = time.time()
        status, qa_list = generate_qa_list(context, num_questions, model_choice)

        # generate_qa_list signals an error by returning a visible status
        # update (gr.update yields a dict, so its fields can be inspected).
        if status["visible"]:
            return [status, gr.update(visible=False), gr.update(visible=False)] + [gr.update(visible=False) for _ in range(20)]

        updates = []
        for i in range(20):
            if i < len(qa_list):
                updates.append(gr.update(value=qa_list[i], visible=True))
            else:
                updates.append(gr.update(visible=False))

        elapsed = time.time() - start_time
        elapsed_msg = f"⏱️ Processing time: {elapsed:.2f} seconds"
        elapsed_md = gr.update(value=elapsed_msg, visible=True)

        return [gr.update(visible=False), gr.update(visible=True), elapsed_md] + updates

    context_input.change(
        fn=analyze_context,
        inputs=[context_input, num_input],
        outputs=[status_message, num_input, generate_btn, elapsed_time_md]
    )

    def show_processing():
        return gr.update(value="⏳ Processing...", visible=True)

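    # Show the progress message first, then run generation once it is visible.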
    generate_btn.click(
        fn=show_processing,
        inputs=[],
        outputs=[status_message]
    ).then(
        fn=run_generation,
        inputs=[context_input, num_input, model_choice],
        outputs=[status_message, output_container, elapsed_time_md] + result_md_list
    )

demo.launch()