import gradio as gr import spacy from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, BartTokenizer, BartForConditionalGeneration import torch import time import re import os # Tải mô hình spaCy if not spacy.util.is_package("en_core_web_md"): print("Đang tải mô hình spaCy 'en_core_web_md'...") spacy.cli.download("en_core_web_md") nlp = spacy.load("en_core_web_md") # Đường dẫn mô hình MODEL_PATHS = { "prophetnet-large-uncased-finetuned": "ManB2207540/prophetnet_large_uncased_qg_squad_2epoch_finetuned", "bart-finetuned": "mghan3624/bart_qg_finetune_squad", "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned" } # Hàm tải mô hình T5 def load_t5_pipeline(model_path): try: tokenizer = T5Tokenizer.from_pretrained(model_path) model = T5ForConditionalGeneration.from_pretrained(model_path) return pipeline( "text2text-generation", model=model, tokenizer=tokenizer, max_length=256, num_return_sequences=1, device=0 if torch.cuda.is_available() else -1 ) except Exception as e: print(f"Failed to load T5 pipeline for {model_path}: {e}") return None # Ham tải mô hình ProphetNet def load_prophetnet_pipeline(model_path): try: tokenizer = ProphetNetTokenizer.from_pretrained(model_path) model = ProphetNetForConditionalGeneration.from_pretrained(model_path) return pipeline( "text2text-generation", model=model, tokenizer=tokenizer, max_length=256, num_return_sequences=1, device=0 if torch.cuda.is_available() else -1 ) except Exception as e: print(f"Failed to load ProphetNet pipeline for {model_path}: {e}") return None # Hàm tải mô hình Bart def load_bart_pipeline(model_path): try: config = AutoConfig.from_pretrained(model_path) if getattr(config, "early_stopping", None) is None: config.early_stopping = False tokenizer = BartTokenizer.from_pretrained(model_path) model = BartForConditionalGeneration.from_pretrained(model_path, config=config) return pipeline( "text2text-generation", model=model, tokenizer=tokenizer, max_length=256, num_return_sequences=1, device=0 if torch.cuda.is_available() else -1 ) except Exception as e: print(f"Failed to load Bart pipeline for {model_path}: {e}") return None # Hàm tải mô hình chung def load_pipeline(model_path): try: tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForSeq2SeqLM.from_pretrained(model_path) return pipeline( "text2text-generation", model=model, tokenizer=tokenizer, max_length=256, num_return_sequences=1, device=0 if torch.cuda.is_available() else -1 ) except Exception as e: print(f"Failed to load pipeline for {model_path}: {e}") return None # Cache pipeline pipeline_cache = {} def get_pipeline(model_name): model_path = MODEL_PATHS[model_name] if model_name not in pipeline_cache: if model_name == "t5-small-finetuned": pipeline_cache[model_name] = load_t5_pipeline(model_path) elif model_name == "prophetnet-large-uncased-finetuned": pipeline_cache[model_name] = load_prophetnet_pipeline(model_path) elif model_name == "bart-finetuned": pipeline_cache[model_name] = load_bart_pipeline(model_path) else: pipeline_cache[model_name] = load_pipeline(model_path) return pipeline_cache[model_name] # Tự viết hàm capitalize thông minh def smart_capitalize(text): # Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần text = text.strip() if not text: return text text = text[0].upper() + text[1:] if not re.search(r'[.?!]$', text): text += '.' return text def generate_question(context, answer, model_name): pipe = get_pipeline(model_name) tokenizer = pipe.tokenizer if model_name == "t5-small-finetuned": prompt = f"generate question: context: {context} answer: {answer}" elif model_name == "prophetnet-large-uncased-finetuned": prompt = f"context: {context} answer: {answer}" elif model_name == "bart-finetuned": prompt = f"context: {context} answer: {answer}" else: prompt = f"context: {context} answer: {answer}" print(f"Prompt: {prompt}") # In ra prompt để kiểm tra # Kiểm tra độ dài của prompt encoded = tokenizer(prompt, return_tensors="pt", truncation=False, max_length=512) input_ids = encoded["input_ids"] if input_ids.size(1) > 512: return "❌ Văn bản quá dài. Xin nhập vào văn bản ngắn hơn." # (hơn 512 token) # Proceed with tokenization (with truncation if needed) encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] try: output = pipe.model.generate( input_ids=input_ids.to(pipe.model.device), attention_mask=attention_mask.to(pipe.model.device), max_length=64, num_return_sequences=1, num_beams=4 ) result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip() result = smart_capitalize(result) print(f"Generated question: {result}") # Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?) if not re.search(r'[.?!]$', result): result += '.' return result except Exception as e: return f"Lỗi khi sinh câu hỏi: {e}" def generate_qa_list(context, num_questions, model_choice): doc = nlp(context) entities = list(set([ent.text for ent in doc.ents])) entities = [e for e in entities if len(e.strip().split()) <= 10] # Nếu không tìm thấy thực thể, trả về thông báo lỗi trong status_message if not entities: return gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), [] # Đảm bảo số câu hỏi không vượt quá số thực thể count = min(num_questions, len(entities)) qa_list = [] for i in range(count): answer = entities[i] question = generate_question(context, answer, model_choice) # Nếu có lỗi (như context quá dài), trả về thông báo lỗi trong status_message if question.startswith("❌") or question.startswith("Lỗi"): return gr.update(visible=True, value=question), [] answer = smart_capitalize(entities[i]) qa = f"**{question}**\n
Hiện câu trả lời

{answer}

" qa_list.append(qa) return gr.update(visible=False), qa_list # Tách phần phân tích context và cập nhật slider def analyze_context(context, num_questions): doc = nlp(context) entities = list(set([ent.text for ent in doc.ents])) entities = [e for e in entities if len(e.strip().split()) <= 10] entity_count = len(entities) if entity_count == 0: return ( gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) ) else: return ( gr.update(visible=False), gr.update(visible=True, maximum=entity_count, value=min(num_questions, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"), gr.update(visible=True), gr.update(visible=True) ) with gr.Blocks() as demo: gr.Markdown("## 📘 Hệ thống sinh câu hỏi") with gr.Row(): with gr.Column(scale=4): context_input = gr.Textbox(label="Nhập văn bản", lines=15, placeholder="Nhập đoạn văn bản...") elapsed_time_md = gr.Markdown(visible=False) with gr.Column(scale=1): model_choice = gr.Dropdown( label="Chọn mô hình", choices=list(MODEL_PATHS.keys()), value="t5-small-finetuned" ) num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=20, value=3, step=1, visible=False) generate_btn = gr.Button("🚀 Sinh câu hỏi", visible=False) # Thông báo đang xử lý hoặc không tìm thấy status_message = gr.Markdown(visible=False) # Kết quả hiển thị tại đây with gr.Column(visible=False) as output_container: result_md_list = [gr.Markdown(visible=False) for _ in range(20)] # Xử lý khi bấm nút sinh câu hỏi def run_generation(context, num_questions, model_choice): start_time = time.time() status_message, qa_list = generate_qa_list(context, num_questions, model_choice) # Nếu có lỗi (status_message visible), trả về ngay lập tức if status_message["visible"]: return [status_message, gr.update(visible=False), gr.update(visible=False)] + [gr.update(visible=False) for _ in range(20)] updates = [] for i in range(20): if i < len(qa_list): updates.append(gr.update(value=qa_list[i], visible=True)) else: updates.append(gr.update(visible=False)) elapsed = time.time() - start_time elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây" elapsed_md = gr.update(value=elapsed_msg, visible=True) return [gr.update(visible=False), gr.update(visible=True), elapsed_md] + updates # Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider context_input.change( fn=analyze_context, inputs=[context_input, num_input], # Thêm num_input vào inputs outputs=[status_message, num_input, generate_btn, elapsed_time_md] ) def show_processing(): return gr.update(value="⏳ Đang xử lý...", visible=True) # Khi người dùng bấm nút sinh câu hỏi, hiển thị thông báo đang xử lý và gọi hàm run_generation generate_btn.click( fn=show_processing, inputs=[], outputs=[status_message] ).then( fn=run_generation, inputs=[context_input, num_input, model_choice], outputs=[status_message, output_container, elapsed_time_md] + result_md_list ) demo.launch() # #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py