File size: 11,070 Bytes
c750faa a291170 e9c022b c750faa a291170 c750faa f7f1029 a291170 c750faa a291170 c750faa 9b3b23d a291170 c750faa a291170 e9c022b a291170 e9c022b a291170 e9c022b a291170 c750faa a291170 c750faa a291170 e9c022b a291170 c750faa a291170 e9c022b a291170 e9c022b c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa 5da2849 c750faa e9c022b c750faa a291170 c750faa a291170 159439d c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa a291170 c750faa bb1c3e8 c750faa a291170 c750faa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 |
import gradio as gr
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, BartTokenizer, BartForConditionalGeneration
import torch
import time
import re
import os
# Tải mô hình spaCy
if not spacy.util.is_package("en_core_web_md"):
print("Đang tải mô hình spaCy 'en_core_web_md'...")
spacy.cli.download("en_core_web_md")
nlp = spacy.load("en_core_web_md")
# Đường dẫn mô hình
MODEL_PATHS = {
"prophetnet-large-uncased-finetuned": "ManB2207540/prophetnet_large_uncased_qg_squad_2epoch_finetuned",
"bart-finetuned": "mghan3624/bart_qg_finetune_squad",
"t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
}
# Hàm tải mô hình T5
def load_t5_pipeline(model_path):
try:
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
except Exception as e:
print(f"Failed to load T5 pipeline for {model_path}: {e}")
return None
# Ham tải mô hình ProphetNet
def load_prophetnet_pipeline(model_path):
try:
tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
except Exception as e:
print(f"Failed to load ProphetNet pipeline for {model_path}: {e}")
return None
# Hàm tải mô hình Bart
def load_bart_pipeline(model_path):
try:
config = AutoConfig.from_pretrained(model_path)
if getattr(config, "early_stopping", None) is None:
config.early_stopping = False
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path, config=config)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
except Exception as e:
print(f"Failed to load Bart pipeline for {model_path}: {e}")
return None
# Hàm tải mô hình chung
def load_pipeline(model_path):
try:
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
return pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=256,
num_return_sequences=1,
device=0 if torch.cuda.is_available() else -1
)
except Exception as e:
print(f"Failed to load pipeline for {model_path}: {e}")
return None
# Cache pipeline
pipeline_cache = {}
def get_pipeline(model_name):
model_path = MODEL_PATHS[model_name]
if model_name not in pipeline_cache:
if model_name == "t5-small-finetuned":
pipeline_cache[model_name] = load_t5_pipeline(model_path)
elif model_name == "prophetnet-large-uncased-finetuned":
pipeline_cache[model_name] = load_prophetnet_pipeline(model_path)
elif model_name == "bart-finetuned":
pipeline_cache[model_name] = load_bart_pipeline(model_path)
else:
pipeline_cache[model_name] = load_pipeline(model_path)
return pipeline_cache[model_name]
# Tự viết hàm capitalize thông minh
def smart_capitalize(text):
# Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
text = text.strip()
if not text:
return text
text = text[0].upper() + text[1:]
if not re.search(r'[.?!]$', text):
text += '.'
return text
def generate_question(context, answer, model_name):
pipe = get_pipeline(model_name)
tokenizer = pipe.tokenizer
if model_name == "t5-small-finetuned":
prompt = f"generate question: context: {context} answer: {answer}"
elif model_name == "prophetnet-large-uncased-finetuned":
prompt = f"context: {context} answer: {answer}"
elif model_name == "bart-finetuned":
prompt = f"context: {context} answer: {answer}"
else:
prompt = f"context: {context} answer: {answer}"
print(f"Prompt: {prompt}") # In ra prompt để kiểm tra
# Kiểm tra độ dài của prompt
encoded = tokenizer(prompt, return_tensors="pt", truncation=False, max_length=512)
input_ids = encoded["input_ids"]
if input_ids.size(1) > 512:
return "❌ Văn bản quá dài. Xin nhập vào văn bản ngắn hơn." # (hơn 512 token)
# Proceed with tokenization (with truncation if needed)
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]
try:
output = pipe.model.generate(
input_ids=input_ids.to(pipe.model.device),
attention_mask=attention_mask.to(pipe.model.device),
max_length=64,
num_return_sequences=1,
num_beams=4
)
result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
result = smart_capitalize(result)
print(f"Generated question: {result}")
# Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?)
if not re.search(r'[.?!]$', result):
result += '.'
return result
except Exception as e:
return f"Lỗi khi sinh câu hỏi: {e}"
def generate_qa_list(context, num_questions, model_choice):
doc = nlp(context)
entities = list(set([ent.text for ent in doc.ents]))
entities = [e for e in entities if len(e.strip().split()) <= 10]
# Nếu không tìm thấy thực thể, trả về thông báo lỗi trong status_message
if not entities:
return gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), []
# Đảm bảo số câu hỏi không vượt quá số thực thể
count = min(num_questions, len(entities))
qa_list = []
for i in range(count):
answer = entities[i]
question = generate_question(context, answer, model_choice)
# Nếu có lỗi (như context quá dài), trả về thông báo lỗi trong status_message
if question.startswith("❌") or question.startswith("Lỗi"):
return gr.update(visible=True, value=question), []
answer = smart_capitalize(entities[i])
qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
qa_list.append(qa)
return gr.update(visible=False), qa_list
# Tách phần phân tích context và cập nhật slider
def analyze_context(context, num_questions):
doc = nlp(context)
entities = list(set([ent.text for ent in doc.ents]))
entities = [e for e in entities if len(e.strip().split()) <= 10]
entity_count = len(entities)
if entity_count == 0:
return (
gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False)
)
else:
return (
gr.update(visible=False),
gr.update(visible=True, maximum=entity_count, value=min(num_questions, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
gr.update(visible=True),
gr.update(visible=True)
)
with gr.Blocks() as demo:
gr.Markdown("## 📘 Hệ thống sinh câu hỏi")
with gr.Row():
with gr.Column(scale=4):
context_input = gr.Textbox(label="Nhập văn bản", lines=15, placeholder="Nhập đoạn văn bản...")
elapsed_time_md = gr.Markdown(visible=False)
with gr.Column(scale=1):
model_choice = gr.Dropdown(
label="Chọn mô hình",
choices=list(MODEL_PATHS.keys()),
value="t5-small-finetuned"
)
num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=20, value=3, step=1, visible=False)
generate_btn = gr.Button("🚀 Sinh câu hỏi", visible=False)
# Thông báo đang xử lý hoặc không tìm thấy
status_message = gr.Markdown(visible=False)
# Kết quả hiển thị tại đây
with gr.Column(visible=False) as output_container:
result_md_list = [gr.Markdown(visible=False) for _ in range(20)]
# Xử lý khi bấm nút sinh câu hỏi
def run_generation(context, num_questions, model_choice):
start_time = time.time()
status_message, qa_list = generate_qa_list(context, num_questions, model_choice)
# Nếu có lỗi (status_message visible), trả về ngay lập tức
if status_message["visible"]:
return [status_message, gr.update(visible=False), gr.update(visible=False)] + [gr.update(visible=False) for _ in range(20)]
updates = []
for i in range(20):
if i < len(qa_list):
updates.append(gr.update(value=qa_list[i], visible=True))
else:
updates.append(gr.update(visible=False))
elapsed = time.time() - start_time
elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
elapsed_md = gr.update(value=elapsed_msg, visible=True)
return [gr.update(visible=False), gr.update(visible=True), elapsed_md] + updates
# Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
context_input.change(
fn=analyze_context,
inputs=[context_input, num_input], # Thêm num_input vào inputs
outputs=[status_message, num_input, generate_btn, elapsed_time_md]
)
def show_processing():
return gr.update(value="⏳ Đang xử lý...", visible=True)
# Khi người dùng bấm nút sinh câu hỏi, hiển thị thông báo đang xử lý và gọi hàm run_generation
generate_btn.click(
fn=show_processing,
inputs=[],
outputs=[status_message]
).then(
fn=run_generation,
inputs=[context_input, num_input, model_choice],
outputs=[status_message, output_container, elapsed_time_md] + result_md_list
)
demo.launch()
# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py |