tbtminh's picture
Add t5-small model
4fdd016 verified
raw
history blame
7.61 kB
import gradio as gr
import spacy
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline, T5Tokenizer, T5ForConditionalGeneration
import torch
import time
import re
import os
# Tải mô hình spaCy
if not spacy.util.is_package("en_core_web_md"):
print("Đang tải mô hình spaCy 'en_core_web_md'...")
spacy.cli.download("en_core_web_md") # <--- Lỗi xảy ra ở đây
nlp = spacy.load("en_core_web_md")
print("✅ Đã tải/nạp mô hình spaCy.")
MODEL_PATHS = {
"prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
"prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg",
"t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
}
def load_pipeline(model_path):
tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
def load_t5_pipeline(model_path):
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
pipeline_cache = {}
def get_pipeline(model_name):
model_path = MODEL_PATHS[model_name]
if model_name not in pipeline_cache:
if model_name == "t5-small-finetuned":
pipeline_cache[model_name] = load_t5_pipeline(model_path)
else:
pipeline_cache[model_name] = load_pipeline(model_path)
return pipeline_cache[model_name]
# Tự viết hàm capitalize thông minh
def smart_capitalize(text):
# Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
text = text.strip()
if not text:
return text
text = text[0].upper() + text[1:]
if not re.search(r'[.?!]$', text):
text += '.'
return text
def generate_question(context, answer, model_name):
pipe = get_pipeline(model_name)
tokenizer = pipe.tokenizer
if model_name == "t5-small-finetuned":
prompt = f"generate question: context: {context} answer: {answer}"
else:
prompt = f"context: {context} answer: {answer}"
# Cắt prompt nếu vượt quá giới hạn token
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
try:
output = pipe.model.generate(
input_ids=input_ids.to(pipe.model.device),
attention_mask=attention_mask.to(pipe.model.device),
max_length=64,
num_return_sequences=1,
num_beams=4
)
result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
result = smart_capitalize(result)
print(f"Generated question: {result}")
# Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?)
if not re.search(r'[.?!]$', result):
result += '.'
return result
except Exception as e:
return f"Lỗi khi sinh câu hỏi: {e}"
def generate_qa_list(context, num_questions, model_choice):
doc = nlp(context)
entities = list(set([ent.text for ent in doc.ents]))
entities = [e for e in entities if len(e.strip().split()) <= 10]
if not entities:
return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."]
count = min(num_questions, len(entities))
qa_list = []
for i in range(count):
answer = entities[i]
question = generate_question(context, answer, model_choice)
answer = smart_capitalize(entities[i])
qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
qa_list.append(qa)
return gr.update(visible=True), qa_list
# Tách phần phân tích context và cập nhật slider
def analyze_context(context):
doc = nlp(context)
entities = list(set([ent.text for ent in doc.ents]))
entities = [e for e in entities if len(e.strip().split()) <= 10]
entity_count = len(entities)
if entity_count == 0:
return (
gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False)
)
else:
return (
gr.update(visible=False),
gr.update(visible=True, maximum=entity_count, value=min(3, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
gr.update(visible=True),
gr.update(visible=True)
)
with gr.Blocks() as demo:
gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng Seq2Seq Transformer + spaCy NER")
with gr.Row():
with gr.Column(scale=4):
context_input = gr.Textbox(label="Nhập Context", lines=15, placeholder="Nhập đoạn văn bản...")
elapsed_time_md = gr.Markdown(visible=False)
with gr.Column(scale=1):
model_choice = gr.Dropdown(
label="Chọn mô hình",
choices=list(MODEL_PATHS.keys()),
value="prophetnet2"
)
num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False)
generate_btn = gr.Button("Sinh câu hỏi", visible=False)
# Thông báo đang xử lý hoặc không tìm thấy
status_message = gr.Markdown(visible=False)
# Kết quả hiển thị tại đây
with gr.Column(visible=False) as output_container:
result_md_list = [gr.Markdown(visible=False) for _ in range(5)]
# Xử lý khi bấm nút sinh câu hỏi
def run_generation(context, num_questions, model_choice):
start_time = time.time()
visible_container, qa_list = generate_qa_list(context, num_questions, model_choice)
status_hide = gr.update(visible=False)
updates = []
for i in range(5):
if i < len(qa_list):
updates.append(gr.update(value=qa_list[i], visible=True))
else:
updates.append(gr.update(visible=False))
elapsed = time.time() - start_time
elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
elapsed_md = gr.update(value=elapsed_msg, visible=True)
return [status_hide, visible_container, elapsed_md] + updates
# Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
context_input.change(
fn=analyze_context,
inputs=[context_input],
outputs=[status_message, num_input, generate_btn, elapsed_time_md]
)
def show_processing():
return gr.update(value="⏳ Đang xử lý...", visible=True)
generate_btn.click(
fn=show_processing,
inputs=[],
outputs=[status_message]
).then(
fn=run_generation,
inputs=[context_input, num_input, model_choice],
outputs=[status_message, output_container, elapsed_time_md] + result_md_list
)
demo.launch(share=True)
# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py