Commit
·
a291170
1
Parent(s):
a6ed7ac
modify app.py for T5 and requirements.txt
Browse files- app.py +86 -34
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,41 +1,74 @@
|
|
| 1 |
|
| 2 |
import gradio as gr
|
| 3 |
import spacy
|
| 4 |
-
from transformers import
|
|
|
|
| 5 |
import torch
|
| 6 |
import time
|
| 7 |
import re
|
| 8 |
-
import os
|
| 9 |
|
| 10 |
# Tải mô hình spaCy
|
| 11 |
if not spacy.util.is_package("en_core_web_md"):
|
| 12 |
print("Đang tải mô hình spaCy 'en_core_web_md'...")
|
| 13 |
-
spacy.cli.download("en_core_web_md")
|
| 14 |
nlp = spacy.load("en_core_web_md")
|
| 15 |
-
|
|
|
|
| 16 |
MODEL_PATHS = {
|
| 17 |
-
"
|
| 18 |
-
"prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg"
|
|
|
|
|
|
|
| 19 |
}
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def load_pipeline(model_path):
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
|
|
|
| 33 |
pipeline_cache = {}
|
| 34 |
|
| 35 |
def get_pipeline(model_name):
|
| 36 |
model_path = MODEL_PATHS[model_name]
|
| 37 |
if model_name not in pipeline_cache:
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
return pipeline_cache[model_name]
|
| 40 |
|
| 41 |
# Tự viết hàm capitalize thông minh
|
|
@@ -54,9 +87,19 @@ def smart_capitalize(text):
|
|
| 54 |
def generate_question(context, answer, model_name):
|
| 55 |
pipe = get_pipeline(model_name)
|
| 56 |
tokenizer = pipe.tokenizer
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
#
|
| 60 |
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
| 61 |
input_ids = encoded["input_ids"]
|
| 62 |
attention_mask = encoded["attention_mask"]
|
|
@@ -87,23 +130,28 @@ def generate_qa_list(context, num_questions, model_choice):
|
|
| 87 |
entities = list(set([ent.text for ent in doc.ents]))
|
| 88 |
entities = [e for e in entities if len(e.strip().split()) <= 10]
|
| 89 |
|
|
|
|
| 90 |
if not entities:
|
| 91 |
-
return gr.update(visible=True
|
| 92 |
|
|
|
|
| 93 |
count = min(num_questions, len(entities))
|
| 94 |
qa_list = []
|
| 95 |
|
| 96 |
for i in range(count):
|
| 97 |
answer = entities[i]
|
| 98 |
question = generate_question(context, answer, model_choice)
|
|
|
|
|
|
|
|
|
|
| 99 |
answer = smart_capitalize(entities[i])
|
| 100 |
qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
|
| 101 |
qa_list.append(qa)
|
| 102 |
|
| 103 |
-
return gr.update(visible=
|
| 104 |
|
| 105 |
# Tách phần phân tích context và cập nhật slider
|
| 106 |
-
def analyze_context(context):
|
| 107 |
doc = nlp(context)
|
| 108 |
entities = list(set([ent.text for ent in doc.ents]))
|
| 109 |
entities = [e for e in entities if len(e.strip().split()) <= 10]
|
|
@@ -119,13 +167,13 @@ def analyze_context(context):
|
|
| 119 |
else:
|
| 120 |
return (
|
| 121 |
gr.update(visible=False),
|
| 122 |
-
gr.update(visible=True, maximum=entity_count, value=min(
|
| 123 |
gr.update(visible=True),
|
| 124 |
gr.update(visible=True)
|
| 125 |
)
|
| 126 |
|
| 127 |
with gr.Blocks() as demo:
|
| 128 |
-
gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng
|
| 129 |
|
| 130 |
with gr.Row():
|
| 131 |
with gr.Column(scale=4):
|
|
@@ -135,9 +183,9 @@ with gr.Blocks() as demo:
|
|
| 135 |
model_choice = gr.Dropdown(
|
| 136 |
label="Chọn mô hình",
|
| 137 |
choices=list(MODEL_PATHS.keys()),
|
| 138 |
-
value="
|
| 139 |
)
|
| 140 |
-
num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=
|
| 141 |
generate_btn = gr.Button("Sinh câu hỏi", visible=False)
|
| 142 |
|
| 143 |
# Thông báo đang xử lý hoặc không tìm thấy
|
|
@@ -145,16 +193,19 @@ with gr.Blocks() as demo:
|
|
| 145 |
|
| 146 |
# Kết quả hiển thị tại đây
|
| 147 |
with gr.Column(visible=False) as output_container:
|
| 148 |
-
result_md_list = [gr.Markdown(visible=False) for _ in range(
|
| 149 |
|
| 150 |
# Xử lý khi bấm nút sinh câu hỏi
|
| 151 |
def run_generation(context, num_questions, model_choice):
|
| 152 |
start_time = time.time()
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
|
|
|
|
| 158 |
if i < len(qa_list):
|
| 159 |
updates.append(gr.update(value=qa_list[i], visible=True))
|
| 160 |
else:
|
|
@@ -164,12 +215,12 @@ with gr.Blocks() as demo:
|
|
| 164 |
elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
|
| 165 |
elapsed_md = gr.update(value=elapsed_msg, visible=True)
|
| 166 |
|
| 167 |
-
return [
|
| 168 |
|
| 169 |
# Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
|
| 170 |
context_input.change(
|
| 171 |
fn=analyze_context,
|
| 172 |
-
inputs=[context_input],
|
| 173 |
outputs=[status_message, num_input, generate_btn, elapsed_time_md]
|
| 174 |
)
|
| 175 |
|
|
@@ -186,6 +237,7 @@ with gr.Blocks() as demo:
|
|
| 186 |
outputs=[status_message, output_container, elapsed_time_md] + result_md_list
|
| 187 |
)
|
| 188 |
|
| 189 |
-
demo.launch(
|
|
|
|
| 190 |
|
| 191 |
# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py
|
|
|
|
| 1 |
|
| 2 |
import gradio as gr
|
| 3 |
import spacy
|
| 4 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig
|
| 5 |
+
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration
|
| 6 |
import torch
|
| 7 |
import time
|
| 8 |
import re
|
| 9 |
+
import os
|
| 10 |
|
| 11 |
# Tải mô hình spaCy
|
| 12 |
if not spacy.util.is_package("en_core_web_md"):
|
| 13 |
print("Đang tải mô hình spaCy 'en_core_web_md'...")
|
| 14 |
+
spacy.cli.download("en_core_web_md")
|
| 15 |
nlp = spacy.load("en_core_web_md")
|
| 16 |
+
|
| 17 |
+
# Đường dẫn mô hình
|
| 18 |
MODEL_PATHS = {
|
| 19 |
+
"prophetnet-finetuned": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
|
| 20 |
+
"prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg",
|
| 21 |
+
"bart-finetuned": "mghan3624/bart_qg_finetune_squad",
|
| 22 |
+
"t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
|
| 23 |
}
|
| 24 |
|
| 25 |
+
# Hàm tải mô hình T5
|
| 26 |
+
def load_t5_pipeline(model_path):
|
| 27 |
+
try:
|
| 28 |
+
tokenizer = T5Tokenizer.from_pretrained(model_path)
|
| 29 |
+
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
| 30 |
+
return pipeline(
|
| 31 |
+
"text2text-generation",
|
| 32 |
+
model=model,
|
| 33 |
+
tokenizer=tokenizer,
|
| 34 |
+
max_length=256,
|
| 35 |
+
num_return_sequences=1,
|
| 36 |
+
device=0 if torch.cuda.is_available() else -1
|
| 37 |
+
)
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"Failed to load T5 pipeline for {model_path}: {e}")
|
| 40 |
+
return None
|
| 41 |
+
|
| 42 |
+
# Hàm tải mô hình chung
|
| 43 |
def load_pipeline(model_path):
|
| 44 |
+
try:
|
| 45 |
+
config = AutoConfig.from_pretrained(model_path)
|
| 46 |
+
if getattr(config, "early_stopping", None) is None:
|
| 47 |
+
config.early_stopping = False
|
| 48 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 49 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, config=config)
|
| 50 |
+
return pipeline(
|
| 51 |
+
"text2text-generation",
|
| 52 |
+
model=model,
|
| 53 |
+
tokenizer=tokenizer,
|
| 54 |
+
max_length=256,
|
| 55 |
+
num_return_sequences=1,
|
| 56 |
+
device=0 if torch.cuda.is_available() else -1
|
| 57 |
+
)
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"Failed to load pipeline for {model_path}: {e}")
|
| 60 |
+
return None
|
| 61 |
|
| 62 |
+
# Cache pipeline
|
| 63 |
pipeline_cache = {}
|
| 64 |
|
| 65 |
def get_pipeline(model_name):
|
| 66 |
model_path = MODEL_PATHS[model_name]
|
| 67 |
if model_name not in pipeline_cache:
|
| 68 |
+
if model_name == "t5-small-finetuned":
|
| 69 |
+
pipeline_cache[model_name] = load_t5_pipeline(model_path)
|
| 70 |
+
else:
|
| 71 |
+
pipeline_cache[model_name] = load_pipeline(model_path)
|
| 72 |
return pipeline_cache[model_name]
|
| 73 |
|
| 74 |
# Tự viết hàm capitalize thông minh
|
|
|
|
| 87 |
def generate_question(context, answer, model_name):
|
| 88 |
pipe = get_pipeline(model_name)
|
| 89 |
tokenizer = pipe.tokenizer
|
| 90 |
+
if model_name == "t5-small-finetuned":
|
| 91 |
+
prompt = f"generate question: context: {context} answer: {answer}"
|
| 92 |
+
else:
|
| 93 |
+
prompt = f"context: {context} answer: {answer}"
|
| 94 |
+
print(f"Prompt: {prompt}") # In ra prompt để kiểm tra
|
| 95 |
+
|
| 96 |
+
# Kiểm tra độ dài của prompt
|
| 97 |
+
encoded = tokenizer(prompt, return_tensors="pt", truncation=False, max_length=512)
|
| 98 |
+
input_ids = encoded["input_ids"]
|
| 99 |
+
if input_ids.size(1) > 512:
|
| 100 |
+
return "❌ Context quá dài (hơn 512 token). Xin nhập vào context ngắn hơn."
|
| 101 |
|
| 102 |
+
# Proceed with tokenization (with truncation if needed)
|
| 103 |
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
| 104 |
input_ids = encoded["input_ids"]
|
| 105 |
attention_mask = encoded["attention_mask"]
|
|
|
|
| 130 |
entities = list(set([ent.text for ent in doc.ents]))
|
| 131 |
entities = [e for e in entities if len(e.strip().split()) <= 10]
|
| 132 |
|
| 133 |
+
# Nếu không tìm thấy thực thể, trả về thông báo lỗi trong status_message
|
| 134 |
if not entities:
|
| 135 |
+
return gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), []
|
| 136 |
|
| 137 |
+
# Đảm bảo số câu hỏi không vượt quá số thực thể
|
| 138 |
count = min(num_questions, len(entities))
|
| 139 |
qa_list = []
|
| 140 |
|
| 141 |
for i in range(count):
|
| 142 |
answer = entities[i]
|
| 143 |
question = generate_question(context, answer, model_choice)
|
| 144 |
+
# Nếu có lỗi (như context quá dài), trả về thông báo lỗi trong status_message
|
| 145 |
+
if question.startswith("❌") or question.startswith("Lỗi"):
|
| 146 |
+
return gr.update(visible=True, value=question), []
|
| 147 |
answer = smart_capitalize(entities[i])
|
| 148 |
qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
|
| 149 |
qa_list.append(qa)
|
| 150 |
|
| 151 |
+
return gr.update(visible=False), qa_list
|
| 152 |
|
| 153 |
# Tách phần phân tích context và cập nhật slider
|
| 154 |
+
def analyze_context(context, num_questions):
|
| 155 |
doc = nlp(context)
|
| 156 |
entities = list(set([ent.text for ent in doc.ents]))
|
| 157 |
entities = [e for e in entities if len(e.strip().split()) <= 10]
|
|
|
|
| 167 |
else:
|
| 168 |
return (
|
| 169 |
gr.update(visible=False),
|
| 170 |
+
gr.update(visible=True, maximum=entity_count, value=min(num_questions, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
|
| 171 |
gr.update(visible=True),
|
| 172 |
gr.update(visible=True)
|
| 173 |
)
|
| 174 |
|
| 175 |
with gr.Blocks() as demo:
|
| 176 |
+
gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng ProphetNet + spaCy NER")
|
| 177 |
|
| 178 |
with gr.Row():
|
| 179 |
with gr.Column(scale=4):
|
|
|
|
| 183 |
model_choice = gr.Dropdown(
|
| 184 |
label="Chọn mô hình",
|
| 185 |
choices=list(MODEL_PATHS.keys()),
|
| 186 |
+
value="t5-small-finetuned"
|
| 187 |
)
|
| 188 |
+
num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=20, value=3, step=1, visible=False)
|
| 189 |
generate_btn = gr.Button("Sinh câu hỏi", visible=False)
|
| 190 |
|
| 191 |
# Thông báo đang xử lý hoặc không tìm thấy
|
|
|
|
| 193 |
|
| 194 |
# Kết quả hiển thị tại đây
|
| 195 |
with gr.Column(visible=False) as output_container:
|
| 196 |
+
result_md_list = [gr.Markdown(visible=False) for _ in range(20)]
|
| 197 |
|
| 198 |
# Xử lý khi bấm nút sinh câu hỏi
|
| 199 |
def run_generation(context, num_questions, model_choice):
|
| 200 |
start_time = time.time()
|
| 201 |
+
status_message, qa_list = generate_qa_list(context, num_questions, model_choice)
|
| 202 |
+
|
| 203 |
+
# Nếu có lỗi (status_message visible), trả về ngay lập tức
|
| 204 |
+
if status_message["visible"]:
|
| 205 |
+
return [status_message, gr.update(visible=False), gr.update(visible=False)] + [gr.update(visible=False) for _ in range(20)]
|
| 206 |
|
| 207 |
+
updates = []
|
| 208 |
+
for i in range(20):
|
| 209 |
if i < len(qa_list):
|
| 210 |
updates.append(gr.update(value=qa_list[i], visible=True))
|
| 211 |
else:
|
|
|
|
| 215 |
elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
|
| 216 |
elapsed_md = gr.update(value=elapsed_msg, visible=True)
|
| 217 |
|
| 218 |
+
return [gr.update(visible=False), gr.update(visible=True), elapsed_md] + updates
|
| 219 |
|
| 220 |
# Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
|
| 221 |
context_input.change(
|
| 222 |
fn=analyze_context,
|
| 223 |
+
inputs=[context_input, num_input], # Thêm num_input vào inputs
|
| 224 |
outputs=[status_message, num_input, generate_btn, elapsed_time_md]
|
| 225 |
)
|
| 226 |
|
|
|
|
| 237 |
outputs=[status_message, output_container, elapsed_time_md] + result_md_list
|
| 238 |
)
|
| 239 |
|
| 240 |
+
demo.launch()
|
| 241 |
+
|
| 242 |
|
| 243 |
# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py
|
requirements.txt
CHANGED
|
@@ -70,3 +70,4 @@ uvicorn==0.35.0
|
|
| 70 |
websockets==15.0.1
|
| 71 |
xxhash==3.5.0
|
| 72 |
yarl==1.20.1
|
|
|
|
|
|
| 70 |
websockets==15.0.1
|
| 71 |
xxhash==3.5.0
|
| 72 |
yarl==1.20.1
|
| 73 |
+
sentencepiece==0.2.0
|