File size: 11,070 Bytes
c750faa
 
 
a291170
e9c022b
c750faa
 
 
a291170
c750faa
f7f1029
 
 
a291170
c750faa
a291170
 
c750faa
9b3b23d
a291170
 
c750faa
 
a291170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9c022b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a291170
 
 
 
e9c022b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a291170
e9c022b
a291170
 
 
 
 
 
 
 
 
 
 
c750faa
a291170
c750faa
 
 
 
 
a291170
 
e9c022b
 
 
 
a291170
 
c750faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a291170
 
e9c022b
 
 
 
a291170
 
 
 
 
 
 
 
e9c022b
c750faa
a291170
c750faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a291170
c750faa
a291170
c750faa
a291170
c750faa
 
 
 
 
 
a291170
 
 
c750faa
 
 
 
a291170
c750faa
 
a291170
c750faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a291170
c750faa
 
 
 
 
5da2849
c750faa
 
 
e9c022b
c750faa
 
 
 
 
a291170
c750faa
a291170
159439d
c750faa
 
 
 
 
 
a291170
c750faa
 
 
 
a291170
 
 
 
 
c750faa
a291170
 
c750faa
 
 
 
 
 
 
 
 
a291170
c750faa
 
 
 
a291170
c750faa
 
 
 
 
 
bb1c3e8
c750faa
 
 
 
 
 
 
 
 
 
a291170
 
c750faa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285

import gradio as gr
import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, BartTokenizer, BartForConditionalGeneration
import torch
import time
import re
import os 

# Tải mô hình spaCy
if not spacy.util.is_package("en_core_web_md"):
    print("Đang tải mô hình spaCy 'en_core_web_md'...")
    spacy.cli.download("en_core_web_md") 
nlp = spacy.load("en_core_web_md")

# Đường dẫn mô hình
MODEL_PATHS = {
    "prophetnet-large-uncased-finetuned": "ManB2207540/prophetnet_large_uncased_qg_squad_2epoch_finetuned",
    "bart-finetuned": "mghan3624/bart_qg_finetune_squad",
    "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
}

# Hàm tải mô hình T5
def load_t5_pipeline(model_path):
    try:
        tokenizer = T5Tokenizer.from_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load T5 pipeline for {model_path}: {e}")
        return None

# Ham tải mô hình ProphetNet
def load_prophetnet_pipeline(model_path):
    try:
        tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
        model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load ProphetNet pipeline for {model_path}: {e}")
        return None

# Hàm tải mô hình Bart
def load_bart_pipeline(model_path):
    try:
        config = AutoConfig.from_pretrained(model_path)
        if getattr(config, "early_stopping", None) is None:
            config.early_stopping = False
        tokenizer = BartTokenizer.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load Bart pipeline for {model_path}: {e}")
        return None

# Hàm tải mô hình chung
def load_pipeline(model_path):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=256,
            num_return_sequences=1,
            device=0 if torch.cuda.is_available() else -1
        )
    except Exception as e:
        print(f"Failed to load pipeline for {model_path}: {e}")
        return None

# Cache pipeline
pipeline_cache = {}

def get_pipeline(model_name):
    model_path = MODEL_PATHS[model_name]
    if model_name not in pipeline_cache:
        if model_name == "t5-small-finetuned":
            pipeline_cache[model_name] = load_t5_pipeline(model_path)
        elif model_name == "prophetnet-large-uncased-finetuned":
            pipeline_cache[model_name] = load_prophetnet_pipeline(model_path)
        elif model_name == "bart-finetuned":
            pipeline_cache[model_name] = load_bart_pipeline(model_path)
        else:
            pipeline_cache[model_name] = load_pipeline(model_path)
    return pipeline_cache[model_name]

# Tự viết hàm capitalize thông minh


def smart_capitalize(text):
    # Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
    text = text.strip()
    if not text:
        return text
    text = text[0].upper() + text[1:]
    if not re.search(r'[.?!]$', text):
        text += '.'
    return text

def generate_question(context, answer, model_name):
    pipe = get_pipeline(model_name)
    tokenizer = pipe.tokenizer
    if model_name == "t5-small-finetuned":
        prompt = f"generate question: context: {context} answer: {answer}"
    elif model_name == "prophetnet-large-uncased-finetuned":
        prompt = f"context: {context} answer: {answer}"
    elif model_name == "bart-finetuned":
        prompt = f"context: {context} answer: {answer}"
    else:
        prompt = f"context: {context} answer: {answer}"
    print(f"Prompt: {prompt}")  # In ra prompt để kiểm tra

    # Kiểm tra độ dài của prompt
    encoded = tokenizer(prompt, return_tensors="pt", truncation=False, max_length=512)
    input_ids = encoded["input_ids"]
    if input_ids.size(1) > 512:
        return "❌ Văn bản quá dài. Xin nhập vào văn bản ngắn hơn." # (hơn 512 token)

    # Proceed with tokenization (with truncation if needed)
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    try:
        output = pipe.model.generate(
            input_ids=input_ids.to(pipe.model.device),
            attention_mask=attention_mask.to(pipe.model.device),
            max_length=64,
            num_return_sequences=1,
            num_beams=4
        )
        result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
        result = smart_capitalize(result)
        print(f"Generated question: {result}")
        # Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?)
        if not re.search(r'[.?!]$', result):
            result += '.'

        return result
    except Exception as e:
        return f"Lỗi khi sinh câu hỏi: {e}"



def generate_qa_list(context, num_questions, model_choice):
    doc = nlp(context)
    entities = list(set([ent.text for ent in doc.ents]))
    entities = [e for e in entities if len(e.strip().split()) <= 10]

    # Nếu không tìm thấy thực thể, trả về thông báo lỗi trong status_message
    if not entities:
        return gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), []

    # Đảm bảo số câu hỏi không vượt quá số thực thể
    count = min(num_questions, len(entities))
    qa_list = []

    for i in range(count):
        answer = entities[i]
        question = generate_question(context, answer, model_choice)
        # Nếu có lỗi (như context quá dài), trả về thông báo lỗi trong status_message
        if question.startswith("❌") or question.startswith("Lỗi"):
            return gr.update(visible=True, value=question), []
        answer = smart_capitalize(entities[i])
        qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
        qa_list.append(qa)

    return gr.update(visible=False), qa_list

# Tách phần phân tích context và cập nhật slider
def analyze_context(context, num_questions):
    doc = nlp(context)
    entities = list(set([ent.text for ent in doc.ents]))
    entities = [e for e in entities if len(e.strip().split()) <= 10]
    entity_count = len(entities)

    if entity_count == 0:
        return (
            gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False)
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True, maximum=entity_count, value=min(num_questions, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
            gr.update(visible=True),
            gr.update(visible=True)
        )

with gr.Blocks() as demo:
    gr.Markdown("## 📘 Hệ thống sinh câu hỏi")

    with gr.Row():
        with gr.Column(scale=4):
            context_input = gr.Textbox(label="Nhập văn bản", lines=15, placeholder="Nhập đoạn văn bản...")
            elapsed_time_md = gr.Markdown(visible=False)
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                label="Chọn mô hình",
                choices=list(MODEL_PATHS.keys()),
                value="t5-small-finetuned"
            )
            num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=20, value=3, step=1, visible=False)
            generate_btn = gr.Button("🚀 Sinh câu hỏi", visible=False)

    # Thông báo đang xử lý hoặc không tìm thấy
    status_message = gr.Markdown(visible=False)

    # Kết quả hiển thị tại đây
    with gr.Column(visible=False) as output_container:
        result_md_list = [gr.Markdown(visible=False) for _ in range(20)]

    # Xử lý khi bấm nút sinh câu hỏi
    def run_generation(context, num_questions, model_choice):
        start_time = time.time()
        status_message, qa_list = generate_qa_list(context, num_questions, model_choice)
        
        # Nếu có lỗi (status_message visible), trả về ngay lập tức
        if status_message["visible"]:
            return [status_message, gr.update(visible=False), gr.update(visible=False)] + [gr.update(visible=False) for _ in range(20)]

        updates = []
        for i in range(20):
            if i < len(qa_list):
                updates.append(gr.update(value=qa_list[i], visible=True))
            else:
                updates.append(gr.update(visible=False))

        elapsed = time.time() - start_time
        elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
        elapsed_md = gr.update(value=elapsed_msg, visible=True)

        return [gr.update(visible=False), gr.update(visible=True), elapsed_md] + updates

    # Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
    context_input.change(
        fn=analyze_context,
        inputs=[context_input, num_input],  # Thêm num_input vào inputs
        outputs=[status_message, num_input, generate_btn, elapsed_time_md]
    )

    def show_processing():
        return gr.update(value="⏳ Đang xử lý...", visible=True)

    # Khi người dùng bấm nút sinh câu hỏi, hiển thị thông báo đang xử lý và gọi hàm run_generation
    generate_btn.click(
        fn=show_processing,
        inputs=[],
        outputs=[status_message]
    ).then(
        fn=run_generation,
        inputs=[context_input, num_input, model_choice],
        outputs=[status_message, output_container, elapsed_time_md] + result_md_list
    )

demo.launch()


# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py