Commit
·
c750faa
1
Parent(s):
e33772f
generate demo
Browse files- README.md +153 -11
- app.py +186 -0
- requirements.txt +70 -0
- test.py +106 -0
README.md
CHANGED
|
@@ -1,14 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
| 1 |
+
# 📘 Hệ thống sinh câu hỏi từ Context (ProphetNet + spaCy NER)
|
| 2 |
+
|
| 3 |
+
Dự án này triển khai một hệ thống sinh câu hỏi tự động dựa trên một đoạn văn bản (context) và các thực thể được trích xuất từ đó. Nó sử dụng mô hình ProphetNet đã được fine-tuned để sinh câu hỏi và thư viện spaCy cho việc trích xuất thực thể có tên (Named Entity Recognition - NER). Giao diện người dùng được xây dựng bằng Gradio, cho phép tương tác dễ dàng.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 🚀 Tính năng chính
|
| 8 |
+
|
| 9 |
+
- **Sinh câu hỏi:** Dựa trên một đoạn văn bản (context) và các câu trả lời (entities) được trích xuất.
|
| 10 |
+
- **Trích xuất thực thể (NER):** Sử dụng `en_core_web_md` của spaCy để xác định các thực thể tiềm năng làm câu trả lời.
|
| 11 |
+
- **Hỗ trợ nhiều mô hình:** Cho phép lựa chọn giữa các phiên bản mô hình ProphetNet đã được huấn luyện.
|
| 12 |
+
- **Giao diện web thân thiện:** Được xây dựng bằng Gradio, dễ dàng sử dụng và kiểm tra.
|
| 13 |
+
- **Khả năng tái tạo:** Hướng dẫn chi tiết để bạn có thể cài đặt và chạy dự án này trên máy của mình.
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 🛠️ Yêu cầu hệ thống
|
| 18 |
+
|
| 19 |
+
- Python 3.8+ (nên sử dụng môi trường ảo như Conda hoặc `venv`).
|
| 20 |
+
- Ít nhất 8GB RAM (để tải các mô hình ngôn ngữ lớn).
|
| 21 |
+
- Đề xuất có GPU với VRAM đủ lớn (ví dụ: 8GB+) để có hiệu suất sinh câu hỏi nhanh hơn. Nếu không có GPU, mô hình sẽ chạy trên CPU nhưng có thể chậm hơn đáng kể.
|
| 22 |
+
|
| 23 |
---
|
| 24 |
+
|
| 25 |
+
## 📦 Hướng dẫn cài đặt và chạy dự án
|
| 26 |
+
|
| 27 |
+
Bạn có thể sử dụng `conda` (nếu đã cài Anaconda) hoặc `venv` để tạo môi trường ảo.
|
| 28 |
+
|
| 29 |
+
### Phương pháp 1: Sử dụng Conda (Khuyến nghị)
|
| 30 |
+
|
| 31 |
+
Nếu bạn đã cài đặt **Anaconda** hoặc **Miniconda**:
|
| 32 |
+
|
| 33 |
+
1. **Tạo và kích hoạt môi trường Conda mới:**
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
conda create -n qg_env python=3.9
|
| 37 |
+
conda activate qg_env
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
(Bạn có thể chọn phiên bản Python khác như 3.8 hoặc 3.10 nếu muốn, nhưng 3.9 là một lựa chọn tốt).
|
| 41 |
+
|
| 42 |
+
2. **Cài đặt các thư viện từ `requirements.txt`:**
|
| 43 |
+
**Quan trọng:** Bạn cần file `requirements.txt` chứa danh sách các thư viện được sử dụng trong dự án này. Nếu bạn chưa có, hãy tạo nó:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# Điều hướng đến thư mục gốc của dự án này
|
| 47 |
+
cd path/to/your/project_folder
|
| 48 |
+
# Tạo file requirements.txt
|
| 49 |
+
pip freeze > requirements.txt
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
Sau khi có file `requirements.txt` trong thư mục gốc của dự án, hãy chạy:
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
pip install -r requirements.txt
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
3. **Tải mô hình ngôn ngữ `en_core_web_md` của spaCy:**
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python -m spacy download en_core_web_md
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
4. **Tải và đặt các mô hình ProphetNet:**
|
| 65 |
+
Dự án này sử dụng các mô hình ProphetNet đã được fine-tuned. Bạn cần tải chúng về và đặt vào các đường dẫn chính xác như đã khai báo trong code (hoặc chỉnh sửa code cho phù hợp với đường dẫn của bạn).
|
| 66 |
+
Theo code mẫu:
|
| 67 |
+
|
| 68 |
+
- `prophetnet1`: `/Users/trantieuman/Downloads/prophetnet_1epoch/prophetnet_context_to_question_finetuned`
|
| 69 |
+
- `prophetnet2`: `/Users/trantieuman/Downloads/prophetnet_2epoch_final/final_model`
|
| 70 |
+
- (Nếu có) `prophetnet3`: `/path/to/prophetnet_model_3`
|
| 71 |
+
|
| 72 |
+
**Lưu ý:** Để dễ dàng quản lý, bạn nên tạo một thư mục con trong dự án (ví dụ: `models/prophetnet_1epoch_finetuned`) và đặt các mô hình vào đó, sau đó cập nhật `MODEL_PATHS` trong code của bạn thành các đường dẫn tương đối. Ví dụ:
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
MODEL_PATHS = {
|
| 76 |
+
"prophetnet1": "./models/prophetnet_1epoch_finetuned",
|
| 77 |
+
"prophetnet2": "./models/prophetnet_2epoch_final",
|
| 78 |
+
# ...
|
| 79 |
+
}
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
Đảm bảo các thư mục mô hình này chứa các tệp như `config.json`, `pytorch_model.bin`, `tokenizer_config.json`, `vocab.json`, v.v.
|
| 83 |
+
|
| 84 |
+
### Phương pháp 2: Sử dụng `venv`
|
| 85 |
+
|
| 86 |
+
1. **Tạo và kích hoạt môi trường ảo mới:**
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
# Điều hướng đến thư mục gốc của dự án này
|
| 90 |
+
cd path/to/your/project_folder
|
| 91 |
+
|
| 92 |
+
# Tạo môi trường ảo
|
| 93 |
+
python -m venv venv_qg
|
| 94 |
+
|
| 95 |
+
# Kích hoạt môi trường ảo
|
| 96 |
+
source venv_qg/bin/activate # Trên Linux/macOS
|
| 97 |
+
# Hoặc: .\venv_qg\Scripts\activate # Trên Windows
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
2. **Cài đặt các thư viện từ `requirements.txt`:**
|
| 101 |
+
Tương tự như bước 2 của phương pháp Conda, tạo file `requirements.txt` nếu chưa có:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
pip freeze > requirements.txt
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
Sau đó cài đặt:
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
pip install -r requirements.txt
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
3. **Tải mô hình ngôn ngữ `en_core_web_md` của spaCy:**
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
python -m spacy download en_core_web_md
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
4. **Tải và đặt các mô hình ProphetNet:**
|
| 120 |
+
Tương tự như bước 4 của phương pháp Conda, đảm bảo các file mô hình ProphetNet đã fine-tuned được đặt đúng đường dẫn.
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## 🏃 Cách chạy dự án
|
| 125 |
+
|
| 126 |
+
Sau khi đã hoàn thành các bước cài đặt và kích hoạt môi trường ảo:
|
| 127 |
+
|
| 128 |
+
1. **Đảm bảo bạn đang ở trong thư mục gốc của dự án.**
|
| 129 |
+
|
| 130 |
+
2. **Chạy script chính:**
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
python demo.py
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
3. **Mở trình duyệt:**
|
| 137 |
+
Khi ứng dụng Gradio khởi chạy, bạn sẽ thấy một URL trong terminal (thường là `http://127.0.0.1:7860` hoặc tương tự). Sao chép URL này và dán vào trình duyệt web của bạn để tương tác với giao diện hệ thống sinh câu hỏi.
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## ⚠️ Lưu ý quan trọng
|
| 142 |
+
|
| 143 |
+
- **Đường dẫn mô hình:** Hãy kiểm tra và điều chỉnh các đường dẫn trong biến `MODEL_PATHS` trong code của bạn (`demo.py` hoặc tên file tương ứng) để chúng trỏ đến đúng vị trí các thư mục mô hình ProphetNet đã được tải về trên máy của bạn.
|
| 144 |
+
- **Hiệu suất GPU:** Việc sử dụng GPU sẽ cải thiện đáng kể tốc độ sinh câu hỏi. Đảm bảo cài đặt CUDA và PyTorch với hỗ trợ CUDA nếu bạn muốn tận dụng GPU.
|
| 145 |
+
- **Kiểm tra cache:** Việc sử dụng quá nhiều model co thể gây tràn cache.
|
| 146 |
+
```bash
|
| 147 |
+
du -sh ~/.cache/huggingface/hub
|
| 148 |
+
```
|
| 149 |
+
- **Xoá cache:** Việc sử dụng quá nhiều model co thể gây tràn cache hãy xoá nếu không sử dụng.
|
| 150 |
+
```bash
|
| 151 |
+
huggingface-cli delete-cache
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
---
|
| 155 |
|
| 156 |
+
Hy vọng hướng dẫn này sẽ giúp bạn và những người khác dễ dàng thiết lập và chạy dự án của mình!
|
app.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import spacy
|
| 4 |
+
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
nlp = spacy.load("en_core_web_md")
|
| 10 |
+
|
| 11 |
+
MODEL_PATHS = {
|
| 12 |
+
"prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
|
| 13 |
+
"prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg"
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
def load_pipeline(model_path):
|
| 17 |
+
tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
|
| 18 |
+
model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
|
| 19 |
+
return pipeline(
|
| 20 |
+
"text2text-generation",
|
| 21 |
+
model=model,
|
| 22 |
+
tokenizer=tokenizer,
|
| 23 |
+
max_length=256,
|
| 24 |
+
num_return_sequences=1,
|
| 25 |
+
device=0 if torch.cuda.is_available() else -1
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
pipeline_cache = {}
|
| 29 |
+
|
| 30 |
+
def get_pipeline(model_name):
|
| 31 |
+
model_path = MODEL_PATHS[model_name]
|
| 32 |
+
if model_name not in pipeline_cache:
|
| 33 |
+
pipeline_cache[model_name] = load_pipeline(model_path)
|
| 34 |
+
return pipeline_cache[model_name]
|
| 35 |
+
|
| 36 |
+
# Tự viết hàm capitalize thông minh
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def smart_capitalize(text):
|
| 40 |
+
# Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
|
| 41 |
+
text = text.strip()
|
| 42 |
+
if not text:
|
| 43 |
+
return text
|
| 44 |
+
text = text[0].upper() + text[1:]
|
| 45 |
+
if not re.search(r'[.?!]$', text):
|
| 46 |
+
text += '.'
|
| 47 |
+
return text
|
| 48 |
+
|
| 49 |
+
def generate_question(context, answer, model_name):
|
| 50 |
+
pipe = get_pipeline(model_name)
|
| 51 |
+
tokenizer = pipe.tokenizer
|
| 52 |
+
prompt = f"context: {context} answer: {answer}"
|
| 53 |
+
|
| 54 |
+
# Cắt prompt nếu vượt quá giới hạn token
|
| 55 |
+
encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
| 56 |
+
input_ids = encoded["input_ids"]
|
| 57 |
+
attention_mask = encoded["attention_mask"]
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
output = pipe.model.generate(
|
| 61 |
+
input_ids=input_ids.to(pipe.model.device),
|
| 62 |
+
attention_mask=attention_mask.to(pipe.model.device),
|
| 63 |
+
max_length=64,
|
| 64 |
+
num_return_sequences=1,
|
| 65 |
+
num_beams=4
|
| 66 |
+
)
|
| 67 |
+
result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
|
| 68 |
+
result = smart_capitalize(result)
|
| 69 |
+
print(f"Generated question: {result}")
|
| 70 |
+
# Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?)
|
| 71 |
+
if not re.search(r'[.?!]$', result):
|
| 72 |
+
result += '.'
|
| 73 |
+
|
| 74 |
+
return result
|
| 75 |
+
except Exception as e:
|
| 76 |
+
return f"Lỗi khi sinh câu hỏi: {e}"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def generate_qa_list(context, num_questions, model_choice):
|
| 81 |
+
doc = nlp(context)
|
| 82 |
+
entities = list(set([ent.text for ent in doc.ents]))
|
| 83 |
+
entities = [e for e in entities if len(e.strip().split()) <= 10]
|
| 84 |
+
|
| 85 |
+
if not entities:
|
| 86 |
+
return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."]
|
| 87 |
+
|
| 88 |
+
count = min(num_questions, len(entities))
|
| 89 |
+
qa_list = []
|
| 90 |
+
|
| 91 |
+
for i in range(count):
|
| 92 |
+
answer = entities[i]
|
| 93 |
+
question = generate_question(context, answer, model_choice)
|
| 94 |
+
answer = smart_capitalize(entities[i])
|
| 95 |
+
qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
|
| 96 |
+
qa_list.append(qa)
|
| 97 |
+
|
| 98 |
+
return gr.update(visible=True), qa_list
|
| 99 |
+
|
| 100 |
+
# Tách phần phân tích context và cập nhật slider
|
| 101 |
+
def analyze_context(context):
|
| 102 |
+
doc = nlp(context)
|
| 103 |
+
entities = list(set([ent.text for ent in doc.ents]))
|
| 104 |
+
entities = [e for e in entities if len(e.strip().split()) <= 10]
|
| 105 |
+
entity_count = len(entities)
|
| 106 |
+
|
| 107 |
+
if entity_count == 0:
|
| 108 |
+
return (
|
| 109 |
+
gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
|
| 110 |
+
gr.update(visible=False),
|
| 111 |
+
gr.update(visible=False),
|
| 112 |
+
gr.update(visible=False)
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return (
|
| 116 |
+
gr.update(visible=False),
|
| 117 |
+
gr.update(visible=True, maximum=entity_count, value=min(3, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
|
| 118 |
+
gr.update(visible=True),
|
| 119 |
+
gr.update(visible=True)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
with gr.Blocks() as demo:
|
| 123 |
+
gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng ProphetNet + spaCy NER")
|
| 124 |
+
|
| 125 |
+
with gr.Row():
|
| 126 |
+
with gr.Column(scale=4):
|
| 127 |
+
context_input = gr.Textbox(label="Nhập Context", lines=15, placeholder="Nhập đoạn văn bản...")
|
| 128 |
+
elapsed_time_md = gr.Markdown(visible=False)
|
| 129 |
+
with gr.Column(scale=1):
|
| 130 |
+
model_choice = gr.Dropdown(
|
| 131 |
+
label="Chọn mô hình",
|
| 132 |
+
choices=list(MODEL_PATHS.keys()),
|
| 133 |
+
value="prophetnet1"
|
| 134 |
+
)
|
| 135 |
+
num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False)
|
| 136 |
+
generate_btn = gr.Button("Sinh câu hỏi", visible=False)
|
| 137 |
+
|
| 138 |
+
# Thông báo đang xử lý hoặc không tìm thấy
|
| 139 |
+
status_message = gr.Markdown(visible=False)
|
| 140 |
+
|
| 141 |
+
# Kết quả hiển thị tại đây
|
| 142 |
+
with gr.Column(visible=False) as output_container:
|
| 143 |
+
result_md_list = [gr.Markdown(visible=False) for _ in range(5)]
|
| 144 |
+
|
| 145 |
+
# Xử lý khi bấm nút sinh câu hỏi
|
| 146 |
+
def run_generation(context, num_questions, model_choice):
|
| 147 |
+
start_time = time.time()
|
| 148 |
+
visible_container, qa_list = generate_qa_list(context, num_questions, model_choice)
|
| 149 |
+
status_hide = gr.update(visible=False)
|
| 150 |
+
updates = []
|
| 151 |
+
|
| 152 |
+
for i in range(5):
|
| 153 |
+
if i < len(qa_list):
|
| 154 |
+
updates.append(gr.update(value=qa_list[i], visible=True))
|
| 155 |
+
else:
|
| 156 |
+
updates.append(gr.update(visible=False))
|
| 157 |
+
|
| 158 |
+
elapsed = time.time() - start_time
|
| 159 |
+
elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
|
| 160 |
+
elapsed_md = gr.update(value=elapsed_msg, visible=True)
|
| 161 |
+
|
| 162 |
+
return [status_hide, visible_container, elapsed_md] + updates
|
| 163 |
+
|
| 164 |
+
# Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
|
| 165 |
+
context_input.change(
|
| 166 |
+
fn=analyze_context,
|
| 167 |
+
inputs=[context_input],
|
| 168 |
+
outputs=[status_message, num_input, generate_btn, elapsed_time_md]
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def show_processing():
|
| 172 |
+
return gr.update(value="⏳ Đang xử lý...", visible=True)
|
| 173 |
+
|
| 174 |
+
generate_btn.click(
|
| 175 |
+
fn=show_processing,
|
| 176 |
+
inputs=[],
|
| 177 |
+
outputs=[status_message]
|
| 178 |
+
).then(
|
| 179 |
+
fn=run_generation,
|
| 180 |
+
inputs=[context_input, num_input, model_choice],
|
| 181 |
+
outputs=[status_message, output_container, elapsed_time_md] + result_md_list
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
demo.launch()
|
| 185 |
+
|
| 186 |
+
# #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py
|
requirements.txt
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==24.1.0
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.12.14
|
| 4 |
+
aiosignal==1.4.0
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.9.0
|
| 7 |
+
attrs==25.3.0
|
| 8 |
+
audioop-lts==0.2.1
|
| 9 |
+
Brotli==1.1.0
|
| 10 |
+
certifi==2025.7.9
|
| 11 |
+
charset-normalizer==3.4.2
|
| 12 |
+
click==8.2.1
|
| 13 |
+
datasets==4.0.0
|
| 14 |
+
dill==0.3.8
|
| 15 |
+
fastapi==0.116.1
|
| 16 |
+
ffmpy==0.6.0
|
| 17 |
+
filelock==3.18.0
|
| 18 |
+
frozenlist==1.7.0
|
| 19 |
+
fsspec==2025.3.0
|
| 20 |
+
gradio==5.38.0
|
| 21 |
+
gradio_client==1.11.0
|
| 22 |
+
groovy==0.1.2
|
| 23 |
+
h11==0.16.0
|
| 24 |
+
hf-xet==1.1.5
|
| 25 |
+
httpcore==1.0.9
|
| 26 |
+
httpx==0.28.1
|
| 27 |
+
huggingface-hub==0.33.2
|
| 28 |
+
idna==3.10
|
| 29 |
+
Jinja2==3.1.6
|
| 30 |
+
markdown-it-py==3.0.0
|
| 31 |
+
MarkupSafe==3.0.2
|
| 32 |
+
mdurl==0.1.2
|
| 33 |
+
multidict==6.6.3
|
| 34 |
+
multiprocess==0.70.16
|
| 35 |
+
numpy==2.3.1
|
| 36 |
+
orjson==3.11.0
|
| 37 |
+
packaging==25.0
|
| 38 |
+
pandas==2.3.1
|
| 39 |
+
pillow==11.3.0
|
| 40 |
+
propcache==0.3.2
|
| 41 |
+
pyarrow==20.0.0
|
| 42 |
+
pydantic==2.11.7
|
| 43 |
+
pydantic_core==2.33.2
|
| 44 |
+
pydub==0.25.1
|
| 45 |
+
Pygments==2.19.2
|
| 46 |
+
python-dateutil==2.9.0.post0
|
| 47 |
+
python-multipart==0.0.20
|
| 48 |
+
pytz==2025.2
|
| 49 |
+
PyYAML==6.0.2
|
| 50 |
+
requests==2.32.4
|
| 51 |
+
rich==14.0.0
|
| 52 |
+
ruff==0.12.3
|
| 53 |
+
safehttpx==0.1.6
|
| 54 |
+
safetensors==0.5.3
|
| 55 |
+
semantic-version==2.10.0
|
| 56 |
+
shellingham==1.5.4
|
| 57 |
+
six==1.17.0
|
| 58 |
+
sniffio==1.3.1
|
| 59 |
+
starlette==0.47.1
|
| 60 |
+
tomlkit==0.13.3
|
| 61 |
+
tqdm==4.67.1
|
| 62 |
+
typer==0.16.0
|
| 63 |
+
typing-inspection==0.4.1
|
| 64 |
+
typing_extensions==4.14.1
|
| 65 |
+
tzdata==2025.2
|
| 66 |
+
urllib3==2.5.0
|
| 67 |
+
uvicorn==0.35.0
|
| 68 |
+
websockets==15.0.1
|
| 69 |
+
xxhash==3.5.0
|
| 70 |
+
yarl==1.20.1
|
test.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# finetuned model
|
| 2 |
+
# Import các thư viện cần thiết
|
| 3 |
+
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import time
|
| 8 |
+
from datasets import Dataset
|
| 9 |
+
|
| 10 |
+
# Hàm tải dữ liệu Parquet với xử lý lỗi và thử lại
|
| 11 |
+
def load_squad_parquet(split='train', max_retries=3, delay=5):
|
| 12 |
+
splits = {'train': 'plain_text/train-00000-of-00001.parquet'}
|
| 13 |
+
path = "hf://datasets/rajpurkar/squad/" + splits[split]
|
| 14 |
+
for attempt in range(max_retries):
|
| 15 |
+
try:
|
| 16 |
+
df = pd.read_parquet(path)
|
| 17 |
+
print(f"Tải tập dữ liệu SQuAD {split} thành công sau {attempt + 1} lần thử!")
|
| 18 |
+
return df
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"Lần thử {attempt + 1}/{max_retries} thất bại: {e}")
|
| 21 |
+
if attempt < max_retries - 1:
|
| 22 |
+
print(f"Đợi {delay} giây trước khi thử lại...")
|
| 23 |
+
time.sleep(delay)
|
| 24 |
+
else:
|
| 25 |
+
print("Đã hết số lần thử. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.")
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
# Tải tập dữ liệu SQuAD (chỉ tải train để kiểm tra)
|
| 29 |
+
train_df = load_squad_parquet('train')
|
| 30 |
+
if train_df is None:
|
| 31 |
+
raise ValueError("Không thể tải tập dữ liệu SQuAD. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.")
|
| 32 |
+
|
| 33 |
+
# Chuyển đổi DataFrame thành Dataset để tương thích với pipeline
|
| 34 |
+
train_ds = Dataset.from_pandas(train_df)
|
| 35 |
+
|
| 36 |
+
# Đường dẫn đến thư mục chứa mô hình và tokenizer đã tinh chỉnh
|
| 37 |
+
model_dir = "/Users/trantieuman/Downloads/prophetnet_1epoch/prophetnet_context_to_question_finetuned"
|
| 38 |
+
|
| 39 |
+
# Kiểm tra xem thư mục tồn tại
|
| 40 |
+
if not os.path.exists(model_dir):
|
| 41 |
+
raise FileNotFoundError(f"Thư mục {model_dir} không tồn tại. Vui lòng kiểm tra lại đường dẫn.")
|
| 42 |
+
|
| 43 |
+
# Danh sách file cần thiết cho mô hình và tokenizer
|
| 44 |
+
required_model_files = ['config.json', 'model.safetensors'] # Chỉ cần model.safetensors vì đã sử dụng định dạng này
|
| 45 |
+
required_tokenizer_files = ['prophetnet.tokenizer', 'tokenizer_config.json'] # File tokenizer cần thiết
|
| 46 |
+
all_files = os.listdir(model_dir)
|
| 47 |
+
missing_model_files = [f for f in required_model_files if f not in all_files]
|
| 48 |
+
missing_tokenizer_files = [f for f in required_tokenizer_files if f not in all_files]
|
| 49 |
+
|
| 50 |
+
if missing_model_files or missing_tokenizer_files:
|
| 51 |
+
print(f"Thiếu file trong {model_dir}:")
|
| 52 |
+
if missing_model_files:
|
| 53 |
+
print(f" - File mô hình thiếu: {missing_model_files}")
|
| 54 |
+
if missing_tokenizer_files:
|
| 55 |
+
print(f" - File tokenizer thiếu: {missing_tokenizer_files}")
|
| 56 |
+
raise FileNotFoundError("Vui lòng cung cấp đầy đủ file mô hình và tokenizer.")
|
| 57 |
+
|
| 58 |
+
# Khởi tạo tokenizer và mô hình từ thư mục đã tinh chỉnh
|
| 59 |
+
try:
|
| 60 |
+
# Chỉ định rõ ràng rằng sử dụng định dạng safetensors
|
| 61 |
+
tokenizer = ProphetNetTokenizer.from_pretrained(model_dir)
|
| 62 |
+
model = ProphetNetForConditionalGeneration.from_pretrained(model_dir)
|
| 63 |
+
print("Tải mô hình và tokenizer từ thư mục đã tinh chỉnh thành công!")
|
| 64 |
+
except Exception as e:
|
| 65 |
+
raise RuntimeError(f"Lỗi khi tải mô hình/tokenizer: {e}. Vui lòng kiểm tra cấu trúc thư mục hoặc cập nhật thư viện transformers.")
|
| 66 |
+
|
| 67 |
+
# Tạo pipeline để tạo câu hỏi (question generation)
|
| 68 |
+
pipe = pipeline(
|
| 69 |
+
"text2text-generation",
|
| 70 |
+
model=model,
|
| 71 |
+
tokenizer=tokenizer,
|
| 72 |
+
max_length=256, # Giới hạn độ dài tối đa của câu hỏi
|
| 73 |
+
num_return_sequences=1, # Tạo 1 câu hỏi duy nhất
|
| 74 |
+
device=0 if torch.cuda.is_available() else -1 # Sử dụng GPU nếu có, mặc định CPU
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Hàm tạo câu hỏi từ context và answer
|
| 78 |
+
def generate_question(context, answer):
|
| 79 |
+
# Định dạng input theo cách mô hình đã được tinh chỉnh
|
| 80 |
+
input_text = f"context: {context} answer: {answer}"
|
| 81 |
+
try:
|
| 82 |
+
result = pipe(input_text)[0]['generated_text']
|
| 83 |
+
return result
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"Lỗi khi tạo câu hỏi: {e}")
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
# Thử nghiệm pipeline với một ví dụ
|
| 89 |
+
context = "The Vatican Apostolic Library is located in Vatican City."
|
| 90 |
+
answer = "Vatican City"
|
| 91 |
+
question = generate_question(context, answer)
|
| 92 |
+
if question:
|
| 93 |
+
print(f"Context: {context}")
|
| 94 |
+
print(f"Answer: {answer}")
|
| 95 |
+
print(f"Generated Question: {question}")
|
| 96 |
+
|
| 97 |
+
# (Tùy chọn) Kiểm tra với dữ liệu từ SQuAD
|
| 98 |
+
sample = train_ds[0] # Lấy mẫu đầu tiên từ tập dữ liệu
|
| 99 |
+
context_sample = sample['context']
|
| 100 |
+
answer_sample = sample['answers']['text'][0] if sample['answers']['text'] else "No answer"
|
| 101 |
+
question_sample = generate_question(context_sample, answer_sample)
|
| 102 |
+
if question_sample:
|
| 103 |
+
print(f"\nSample Context: {context_sample}")
|
| 104 |
+
print(f"Sample Answer: {answer_sample}")
|
| 105 |
+
print(f"Generated Question: {question_sample}")
|
| 106 |
+
# /Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/test.py
|