import gradio as gr import spacy from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline, T5Tokenizer, T5ForConditionalGeneration import torch import time import re import os # Tải mô hình spaCy if not spacy.util.is_package("en_core_web_md"): print("Đang tải mô hình spaCy 'en_core_web_md'...") spacy.cli.download("en_core_web_md") # <--- Lỗi xảy ra ở đây nlp = spacy.load("en_core_web_md") print("✅ Đã tải/nạp mô hình spaCy.") MODEL_PATHS = { "prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break", "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg", "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned" } def load_pipeline(model_path): tokenizer = ProphetNetTokenizer.from_pretrained(model_path) model = ProphetNetForConditionalGeneration.from_pretrained(model_path) return pipeline( "text2text-generation", model=model, tokenizer=tokenizer, max_length=256, num_return_sequences=1, device=0 if torch.cuda.is_available() else -1 ) def load_t5_pipeline(model_path): tokenizer = T5Tokenizer.from_pretrained(model_path) model = T5ForConditionalGeneration.from_pretrained(model_path) return pipeline( "text2text-generation", model=model, tokenizer=tokenizer, max_length=256, num_return_sequences=1, device=0 if torch.cuda.is_available() else -1 ) pipeline_cache = {} def get_pipeline(model_name): model_path = MODEL_PATHS[model_name] if model_name not in pipeline_cache: if model_name == "t5-small-finetuned": pipeline_cache[model_name] = load_t5_pipeline(model_path) else: pipeline_cache[model_name] = load_pipeline(model_path) return pipeline_cache[model_name] # Tự viết hàm capitalize thông minh def smart_capitalize(text): # Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần text = text.strip() if not text: return text text = text[0].upper() + text[1:] if not re.search(r'[.?!]$', text): text += '.' return text def generate_question(context, answer, model_name): pipe = get_pipeline(model_name) tokenizer = pipe.tokenizer if model_name == "t5-small-finetuned": prompt = f"generate question: context: {context} answer: {answer}" else: prompt = f"context: {context} answer: {answer}" # Cắt prompt nếu vượt quá giới hạn token encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] try: output = pipe.model.generate( input_ids=input_ids.to(pipe.model.device), attention_mask=attention_mask.to(pipe.model.device), max_length=64, num_return_sequences=1, num_beams=4 ) result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip() result = smart_capitalize(result) print(f"Generated question: {result}") # Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?) if not re.search(r'[.?!]$', result): result += '.' return result except Exception as e: return f"Lỗi khi sinh câu hỏi: {e}" def generate_qa_list(context, num_questions, model_choice): doc = nlp(context) entities = list(set([ent.text for ent in doc.ents])) entities = [e for e in entities if len(e.strip().split()) <= 10] if not entities: return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."] count = min(num_questions, len(entities)) qa_list = [] for i in range(count): answer = entities[i] question = generate_question(context, answer, model_choice) answer = smart_capitalize(entities[i]) qa = f"**{question}**\n
Hiện câu trả lời

{answer}

" qa_list.append(qa) return gr.update(visible=True), qa_list # Tách phần phân tích context và cập nhật slider def analyze_context(context): doc = nlp(context) entities = list(set([ent.text for ent in doc.ents])) entities = [e for e in entities if len(e.strip().split()) <= 10] entity_count = len(entities) if entity_count == 0: return ( gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) ) else: return ( gr.update(visible=False), gr.update(visible=True, maximum=entity_count, value=min(3, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"), gr.update(visible=True), gr.update(visible=True) ) with gr.Blocks() as demo: gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng Seq2Seq Transformer + spaCy NER") with gr.Row(): with gr.Column(scale=4): context_input = gr.Textbox(label="Nhập Context", lines=15, placeholder="Nhập đoạn văn bản...") elapsed_time_md = gr.Markdown(visible=False) with gr.Column(scale=1): model_choice = gr.Dropdown( label="Chọn mô hình", choices=list(MODEL_PATHS.keys()), value="prophetnet2" ) num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False) generate_btn = gr.Button("Sinh câu hỏi", visible=False) # Thông báo đang xử lý hoặc không tìm thấy status_message = gr.Markdown(visible=False) # Kết quả hiển thị tại đây with gr.Column(visible=False) as output_container: result_md_list = [gr.Markdown(visible=False) for _ in range(5)] # Xử lý khi bấm nút sinh câu hỏi def run_generation(context, num_questions, model_choice): start_time = time.time() visible_container, qa_list = generate_qa_list(context, num_questions, model_choice) status_hide = gr.update(visible=False) updates = [] for i in range(5): if i < len(qa_list): updates.append(gr.update(value=qa_list[i], visible=True)) else: updates.append(gr.update(visible=False)) elapsed = time.time() - start_time elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây" elapsed_md = gr.update(value=elapsed_msg, visible=True) return [status_hide, visible_container, elapsed_md] + updates # Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider context_input.change( fn=analyze_context, inputs=[context_input], outputs=[status_message, num_input, generate_btn, elapsed_time_md] ) def show_processing(): return gr.update(value="⏳ Đang xử lý...", visible=True) generate_btn.click( fn=show_processing, inputs=[], outputs=[status_message] ).then( fn=run_generation, inputs=[context_input, num_input, model_choice], outputs=[status_message, output_container, elapsed_time_md] + result_md_list ) demo.launch(share=True) # #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py