ManB2207540 commited on
Commit
c750faa
·
1 Parent(s): e33772f

generate demo

Browse files
Files changed (4) hide show
  1. README.md +153 -11
  2. app.py +186 -0
  3. requirements.txt +70 -0
  4. test.py +106 -0
README.md CHANGED
@@ -1,14 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Demo Question Generation
3
- emoji: 📈
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.38.0
8
- app_file: app.py
9
- pinned: false
10
- license: unknown
11
- short_description: A demo for using transformer models for question generation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # 📘 Hệ thống sinh câu hỏi từ Context (ProphetNet + spaCy NER)
2
+
3
+ Dự án này triển khai một hệ thống sinh câu hỏi tự động dựa trên một đoạn văn bản (context) và các thực thể được trích xuất từ đó. Nó sử dụng mô hình ProphetNet đã được fine-tuned để sinh câu hỏi và thư viện spaCy cho việc trích xuất thực thể có tên (Named Entity Recognition - NER). Giao diện người dùng được xây dựng bằng Gradio, cho phép tương tác dễ dàng.
4
+
5
+ ---
6
+
7
+ ## 🚀 Tính năng chính
8
+
9
+ - **Sinh câu hỏi:** Dựa trên một đoạn văn bản (context) và các câu trả lời (entities) được trích xuất.
10
+ - **Trích xuất thực thể (NER):** Sử dụng `en_core_web_md` của spaCy để xác định các thực thể tiềm năng làm câu trả lời.
11
+ - **Hỗ trợ nhiều mô hình:** Cho phép lựa chọn giữa các phiên bản mô hình ProphetNet đã được huấn luyện.
12
+ - **Giao diện web thân thiện:** Được xây dựng bằng Gradio, dễ dàng sử dụng và kiểm tra.
13
+ - **Khả năng tái tạo:** Hướng dẫn chi tiết để bạn có thể cài đặt và chạy dự án này trên máy của mình.
14
+
15
+ ---
16
+
17
+ ## 🛠️ Yêu cầu hệ thống
18
+
19
+ - Python 3.8+ (nên sử dụng môi trường ảo như Conda hoặc `venv`).
20
+ - Ít nhất 8GB RAM (để tải các mô hình ngôn ngữ lớn).
21
+ - Đề xuất có GPU với VRAM đủ lớn (ví dụ: 8GB+) để có hiệu suất sinh câu hỏi nhanh hơn. Nếu không có GPU, mô hình sẽ chạy trên CPU nhưng có thể chậm hơn đáng kể.
22
+
23
  ---
24
+
25
+ ## 📦 Hướng dẫn cài đặt và chạy dự án
26
+
27
+ Bạn có thể sử dụng `conda` (nếu đã cài Anaconda) hoặc `venv` để tạo môi trường ảo.
28
+
29
+ ### Phương pháp 1: Sử dụng Conda (Khuyến nghị)
30
+
31
+ Nếu bạn đã cài đặt **Anaconda** hoặc **Miniconda**:
32
+
33
+ 1. **Tạo kích hoạt môi trường Conda mới:**
34
+
35
+ ```bash
36
+ conda create -n qg_env python=3.9
37
+ conda activate qg_env
38
+ ```
39
+
40
+ (Bạn có thể chọn phiên bản Python khác như 3.8 hoặc 3.10 nếu muốn, nhưng 3.9 là một lựa chọn tốt).
41
+
42
+ 2. **Cài đặt các thư viện từ `requirements.txt`:**
43
+ **Quan trọng:** Bạn cần file `requirements.txt` chứa danh sách các thư viện được sử dụng trong dự án này. Nếu bạn chưa có, hãy tạo nó:
44
+
45
+ ```bash
46
+ # Điều hướng đến thư mục gốc của dự án này
47
+ cd path/to/your/project_folder
48
+ # Tạo file requirements.txt
49
+ pip freeze > requirements.txt
50
+ ```
51
+
52
+ Sau khi có file `requirements.txt` trong thư mục gốc của dự án, hãy chạy:
53
+
54
+ ```bash
55
+ pip install -r requirements.txt
56
+ ```
57
+
58
+ 3. **Tải mô hình ngôn ngữ `en_core_web_md` của spaCy:**
59
+
60
+ ```bash
61
+ python -m spacy download en_core_web_md
62
+ ```
63
+
64
+ 4. **Tải và đặt các mô hình ProphetNet:**
65
+ Dự án này sử dụng các mô hình ProphetNet đã được fine-tuned. Bạn cần tải chúng về và đặt vào các đường dẫn chính xác như đã khai báo trong code (hoặc chỉnh sửa code cho phù hợp với đường dẫn của bạn).
66
+ Theo code mẫu:
67
+
68
+ - `prophetnet1`: `/Users/trantieuman/Downloads/prophetnet_1epoch/prophetnet_context_to_question_finetuned`
69
+ - `prophetnet2`: `/Users/trantieuman/Downloads/prophetnet_2epoch_final/final_model`
70
+ - (Nếu có) `prophetnet3`: `/path/to/prophetnet_model_3`
71
+
72
+ **Lưu ý:** Để dễ dàng quản lý, bạn nên tạo một thư mục con trong dự án (ví dụ: `models/prophetnet_1epoch_finetuned`) và đặt các mô hình vào đó, sau đó cập nhật `MODEL_PATHS` trong code của bạn thành các đường dẫn tương đối. Ví dụ:
73
+
74
+ ```python
75
+ MODEL_PATHS = {
76
+ "prophetnet1": "./models/prophetnet_1epoch_finetuned",
77
+ "prophetnet2": "./models/prophetnet_2epoch_final",
78
+ # ...
79
+ }
80
+ ```
81
+
82
+ Đảm bảo các thư mục mô hình này chứa các tệp như `config.json`, `pytorch_model.bin`, `tokenizer_config.json`, `vocab.json`, v.v.
83
+
84
+ ### Phương pháp 2: Sử dụng `venv`
85
+
86
+ 1. **Tạo và kích hoạt môi trường ảo mới:**
87
+
88
+ ```bash
89
+ # Điều hướng đến thư mục gốc của dự án này
90
+ cd path/to/your/project_folder
91
+
92
+ # Tạo môi trường ảo
93
+ python -m venv venv_qg
94
+
95
+ # Kích hoạt môi trường ảo
96
+ source venv_qg/bin/activate # Trên Linux/macOS
97
+ # Hoặc: .\venv_qg\Scripts\activate # Trên Windows
98
+ ```
99
+
100
+ 2. **Cài đặt các thư viện từ `requirements.txt`:**
101
+ Tương tự như bước 2 của phương pháp Conda, tạo file `requirements.txt` nếu chưa có:
102
+
103
+ ```bash
104
+ pip freeze > requirements.txt
105
+ ```
106
+
107
+ Sau đó cài đặt:
108
+
109
+ ```bash
110
+ pip install -r requirements.txt
111
+ ```
112
+
113
+ 3. **Tải mô hình ngôn ngữ `en_core_web_md` của spaCy:**
114
+
115
+ ```bash
116
+ python -m spacy download en_core_web_md
117
+ ```
118
+
119
+ 4. **Tải và đặt các mô hình ProphetNet:**
120
+ Tương tự như bước 4 của phương pháp Conda, đảm bảo các file mô hình ProphetNet đã fine-tuned được đặt đúng đường dẫn.
121
+
122
+ ---
123
+
124
+ ## 🏃 Cách chạy dự án
125
+
126
+ Sau khi đã hoàn thành các bước cài đặt và kích hoạt môi trường ảo:
127
+
128
+ 1. **Đảm bảo bạn đang ở trong thư mục gốc của dự án.**
129
+
130
+ 2. **Chạy script chính:**
131
+
132
+ ```bash
133
+ python demo.py
134
+ ```
135
+
136
+ 3. **Mở trình duyệt:**
137
+ Khi ứng dụng Gradio khởi chạy, bạn sẽ thấy một URL trong terminal (thường là `http://127.0.0.1:7860` hoặc tương tự). Sao chép URL này và dán vào trình duyệt web của bạn để tương tác với giao diện hệ thống sinh câu hỏi.
138
+
139
+ ---
140
+
141
+ ## ⚠️ Lưu ý quan trọng
142
+
143
+ - **Đường dẫn mô hình:** Hãy kiểm tra và điều chỉnh các đường dẫn trong biến `MODEL_PATHS` trong code của bạn (`demo.py` hoặc tên file tương ứng) để chúng trỏ đến đúng vị trí các thư mục mô hình ProphetNet đã được tải về trên máy của bạn.
144
+ - **Hiệu suất GPU:** Việc sử dụng GPU sẽ cải thiện đáng kể tốc độ sinh câu hỏi. Đảm bảo cài đặt CUDA và PyTorch với hỗ trợ CUDA nếu bạn muốn tận dụng GPU.
145
+ - **Kiểm tra cache:** Việc sử dụng quá nhiều model co thể gây tràn cache.
146
+ ```bash
147
+ du -sh ~/.cache/huggingface/hub
148
+ ```
149
+ - **Xoá cache:** Việc sử dụng quá nhiều model co thể gây tràn cache hãy xoá nếu không sử dụng.
150
+ ```bash
151
+ huggingface-cli delete-cache
152
+ ```
153
+
154
  ---
155
 
156
+ Hy vọng hướng dẫn này sẽ giúp bạn và những người khác dễ dàng thiết lập và chạy dự án của mình!
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import spacy
4
+ from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
5
+ import torch
6
+ import time
7
+ import re
8
+
9
+ nlp = spacy.load("en_core_web_md")
10
+
11
+ MODEL_PATHS = {
12
+ "prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
13
+ "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg"
14
+ }
15
+
16
+ def load_pipeline(model_path):
17
+ tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
18
+ model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
19
+ return pipeline(
20
+ "text2text-generation",
21
+ model=model,
22
+ tokenizer=tokenizer,
23
+ max_length=256,
24
+ num_return_sequences=1,
25
+ device=0 if torch.cuda.is_available() else -1
26
+ )
27
+
28
+ pipeline_cache = {}
29
+
30
+ def get_pipeline(model_name):
31
+ model_path = MODEL_PATHS[model_name]
32
+ if model_name not in pipeline_cache:
33
+ pipeline_cache[model_name] = load_pipeline(model_path)
34
+ return pipeline_cache[model_name]
35
+
36
+ # Tự viết hàm capitalize thông minh
37
+
38
+
39
+ def smart_capitalize(text):
40
+ # Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
41
+ text = text.strip()
42
+ if not text:
43
+ return text
44
+ text = text[0].upper() + text[1:]
45
+ if not re.search(r'[.?!]$', text):
46
+ text += '.'
47
+ return text
48
+
49
+ def generate_question(context, answer, model_name):
50
+ pipe = get_pipeline(model_name)
51
+ tokenizer = pipe.tokenizer
52
+ prompt = f"context: {context} answer: {answer}"
53
+
54
+ # Cắt prompt nếu vượt quá giới hạn token
55
+ encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
56
+ input_ids = encoded["input_ids"]
57
+ attention_mask = encoded["attention_mask"]
58
+
59
+ try:
60
+ output = pipe.model.generate(
61
+ input_ids=input_ids.to(pipe.model.device),
62
+ attention_mask=attention_mask.to(pipe.model.device),
63
+ max_length=64,
64
+ num_return_sequences=1,
65
+ num_beams=4
66
+ )
67
+ result = pipe.tokenizer.decode(output[0], skip_special_tokens=True).strip()
68
+ result = smart_capitalize(result)
69
+ print(f"Generated question: {result}")
70
+ # Thêm dấu chấm nếu chưa có (và không kết thúc bằng ! hay ?)
71
+ if not re.search(r'[.?!]$', result):
72
+ result += '.'
73
+
74
+ return result
75
+ except Exception as e:
76
+ return f"Lỗi khi sinh câu hỏi: {e}"
77
+
78
+
79
+
80
+ def generate_qa_list(context, num_questions, model_choice):
81
+ doc = nlp(context)
82
+ entities = list(set([ent.text for ent in doc.ents]))
83
+ entities = [e for e in entities if len(e.strip().split()) <= 10]
84
+
85
+ if not entities:
86
+ return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."]
87
+
88
+ count = min(num_questions, len(entities))
89
+ qa_list = []
90
+
91
+ for i in range(count):
92
+ answer = entities[i]
93
+ question = generate_question(context, answer, model_choice)
94
+ answer = smart_capitalize(entities[i])
95
+ qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
96
+ qa_list.append(qa)
97
+
98
+ return gr.update(visible=True), qa_list
99
+
100
+ # Tách phần phân tích context và cập nhật slider
101
+ def analyze_context(context):
102
+ doc = nlp(context)
103
+ entities = list(set([ent.text for ent in doc.ents]))
104
+ entities = [e for e in entities if len(e.strip().split()) <= 10]
105
+ entity_count = len(entities)
106
+
107
+ if entity_count == 0:
108
+ return (
109
+ gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."),
110
+ gr.update(visible=False),
111
+ gr.update(visible=False),
112
+ gr.update(visible=False)
113
+ )
114
+ else:
115
+ return (
116
+ gr.update(visible=False),
117
+ gr.update(visible=True, maximum=entity_count, value=min(3, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
118
+ gr.update(visible=True),
119
+ gr.update(visible=True)
120
+ )
121
+
122
+ with gr.Blocks() as demo:
123
+ gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng ProphetNet + spaCy NER")
124
+
125
+ with gr.Row():
126
+ with gr.Column(scale=4):
127
+ context_input = gr.Textbox(label="Nhập Context", lines=15, placeholder="Nhập đoạn văn bản...")
128
+ elapsed_time_md = gr.Markdown(visible=False)
129
+ with gr.Column(scale=1):
130
+ model_choice = gr.Dropdown(
131
+ label="Chọn mô hình",
132
+ choices=list(MODEL_PATHS.keys()),
133
+ value="prophetnet1"
134
+ )
135
+ num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False)
136
+ generate_btn = gr.Button("Sinh câu hỏi", visible=False)
137
+
138
+ # Thông báo đang xử lý hoặc không tìm thấy
139
+ status_message = gr.Markdown(visible=False)
140
+
141
+ # Kết quả hiển thị tại đây
142
+ with gr.Column(visible=False) as output_container:
143
+ result_md_list = [gr.Markdown(visible=False) for _ in range(5)]
144
+
145
+ # Xử lý khi bấm nút sinh câu hỏi
146
+ def run_generation(context, num_questions, model_choice):
147
+ start_time = time.time()
148
+ visible_container, qa_list = generate_qa_list(context, num_questions, model_choice)
149
+ status_hide = gr.update(visible=False)
150
+ updates = []
151
+
152
+ for i in range(5):
153
+ if i < len(qa_list):
154
+ updates.append(gr.update(value=qa_list[i], visible=True))
155
+ else:
156
+ updates.append(gr.update(visible=False))
157
+
158
+ elapsed = time.time() - start_time
159
+ elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
160
+ elapsed_md = gr.update(value=elapsed_msg, visible=True)
161
+
162
+ return [status_hide, visible_container, elapsed_md] + updates
163
+
164
+ # Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
165
+ context_input.change(
166
+ fn=analyze_context,
167
+ inputs=[context_input],
168
+ outputs=[status_message, num_input, generate_btn, elapsed_time_md]
169
+ )
170
+
171
+ def show_processing():
172
+ return gr.update(value="⏳ Đang xử lý...", visible=True)
173
+
174
+ generate_btn.click(
175
+ fn=show_processing,
176
+ inputs=[],
177
+ outputs=[status_message]
178
+ ).then(
179
+ fn=run_generation,
180
+ inputs=[context_input, num_input, model_choice],
181
+ outputs=[status_message, output_container, elapsed_time_md] + result_md_list
182
+ )
183
+
184
+ demo.launch()
185
+
186
+ # #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py
requirements.txt ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.14
4
+ aiosignal==1.4.0
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ attrs==25.3.0
8
+ audioop-lts==0.2.1
9
+ Brotli==1.1.0
10
+ certifi==2025.7.9
11
+ charset-normalizer==3.4.2
12
+ click==8.2.1
13
+ datasets==4.0.0
14
+ dill==0.3.8
15
+ fastapi==0.116.1
16
+ ffmpy==0.6.0
17
+ filelock==3.18.0
18
+ frozenlist==1.7.0
19
+ fsspec==2025.3.0
20
+ gradio==5.38.0
21
+ gradio_client==1.11.0
22
+ groovy==0.1.2
23
+ h11==0.16.0
24
+ hf-xet==1.1.5
25
+ httpcore==1.0.9
26
+ httpx==0.28.1
27
+ huggingface-hub==0.33.2
28
+ idna==3.10
29
+ Jinja2==3.1.6
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==3.0.2
32
+ mdurl==0.1.2
33
+ multidict==6.6.3
34
+ multiprocess==0.70.16
35
+ numpy==2.3.1
36
+ orjson==3.11.0
37
+ packaging==25.0
38
+ pandas==2.3.1
39
+ pillow==11.3.0
40
+ propcache==0.3.2
41
+ pyarrow==20.0.0
42
+ pydantic==2.11.7
43
+ pydantic_core==2.33.2
44
+ pydub==0.25.1
45
+ Pygments==2.19.2
46
+ python-dateutil==2.9.0.post0
47
+ python-multipart==0.0.20
48
+ pytz==2025.2
49
+ PyYAML==6.0.2
50
+ requests==2.32.4
51
+ rich==14.0.0
52
+ ruff==0.12.3
53
+ safehttpx==0.1.6
54
+ safetensors==0.5.3
55
+ semantic-version==2.10.0
56
+ shellingham==1.5.4
57
+ six==1.17.0
58
+ sniffio==1.3.1
59
+ starlette==0.47.1
60
+ tomlkit==0.13.3
61
+ tqdm==4.67.1
62
+ typer==0.16.0
63
+ typing-inspection==0.4.1
64
+ typing_extensions==4.14.1
65
+ tzdata==2025.2
66
+ urllib3==2.5.0
67
+ uvicorn==0.35.0
68
+ websockets==15.0.1
69
+ xxhash==3.5.0
70
+ yarl==1.20.1
test.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # finetuned model
2
+ # Import các thư viện cần thiết
3
+ from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
4
+ import torch
5
+ import os
6
+ import pandas as pd
7
+ import time
8
+ from datasets import Dataset
9
+
10
+ # Hàm tải dữ liệu Parquet với xử lý lỗi và thử lại
11
+ def load_squad_parquet(split='train', max_retries=3, delay=5):
12
+ splits = {'train': 'plain_text/train-00000-of-00001.parquet'}
13
+ path = "hf://datasets/rajpurkar/squad/" + splits[split]
14
+ for attempt in range(max_retries):
15
+ try:
16
+ df = pd.read_parquet(path)
17
+ print(f"Tải tập dữ liệu SQuAD {split} thành công sau {attempt + 1} lần thử!")
18
+ return df
19
+ except Exception as e:
20
+ print(f"Lần thử {attempt + 1}/{max_retries} thất bại: {e}")
21
+ if attempt < max_retries - 1:
22
+ print(f"Đợi {delay} giây trước khi thử lại...")
23
+ time.sleep(delay)
24
+ else:
25
+ print("Đã hết số lần thử. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.")
26
+ return None
27
+
28
+ # Tải tập dữ liệu SQuAD (chỉ tải train để kiểm tra)
29
+ train_df = load_squad_parquet('train')
30
+ if train_df is None:
31
+ raise ValueError("Không thể tải tập dữ liệu SQuAD. Vui lòng kiểm tra kết nối internet hoặc cài đặt lại môi trường.")
32
+
33
+ # Chuyển đổi DataFrame thành Dataset để tương thích với pipeline
34
+ train_ds = Dataset.from_pandas(train_df)
35
+
36
+ # Đường dẫn đến thư mục chứa mô hình và tokenizer đã tinh chỉnh
37
+ model_dir = "/Users/trantieuman/Downloads/prophetnet_1epoch/prophetnet_context_to_question_finetuned"
38
+
39
+ # Kiểm tra xem thư mục tồn tại
40
+ if not os.path.exists(model_dir):
41
+ raise FileNotFoundError(f"Thư mục {model_dir} không tồn tại. Vui lòng kiểm tra lại đường dẫn.")
42
+
43
+ # Danh sách file cần thiết cho mô hình và tokenizer
44
+ required_model_files = ['config.json', 'model.safetensors'] # Chỉ cần model.safetensors vì đã sử dụng định dạng này
45
+ required_tokenizer_files = ['prophetnet.tokenizer', 'tokenizer_config.json'] # File tokenizer cần thiết
46
+ all_files = os.listdir(model_dir)
47
+ missing_model_files = [f for f in required_model_files if f not in all_files]
48
+ missing_tokenizer_files = [f for f in required_tokenizer_files if f not in all_files]
49
+
50
+ if missing_model_files or missing_tokenizer_files:
51
+ print(f"Thiếu file trong {model_dir}:")
52
+ if missing_model_files:
53
+ print(f" - File mô hình thiếu: {missing_model_files}")
54
+ if missing_tokenizer_files:
55
+ print(f" - File tokenizer thiếu: {missing_tokenizer_files}")
56
+ raise FileNotFoundError("Vui lòng cung cấp đầy đủ file mô hình và tokenizer.")
57
+
58
+ # Khởi tạo tokenizer và mô hình từ thư mục đã tinh chỉnh
59
+ try:
60
+ # Chỉ định rõ ràng rằng sử dụng định dạng safetensors
61
+ tokenizer = ProphetNetTokenizer.from_pretrained(model_dir)
62
+ model = ProphetNetForConditionalGeneration.from_pretrained(model_dir)
63
+ print("Tải mô hình và tokenizer từ thư mục đã tinh chỉnh thành công!")
64
+ except Exception as e:
65
+ raise RuntimeError(f"Lỗi khi tải mô hình/tokenizer: {e}. Vui lòng kiểm tra cấu trúc thư mục hoặc cập nhật thư viện transformers.")
66
+
67
+ # Tạo pipeline để tạo câu hỏi (question generation)
68
+ pipe = pipeline(
69
+ "text2text-generation",
70
+ model=model,
71
+ tokenizer=tokenizer,
72
+ max_length=256, # Giới hạn độ dài tối đa của câu hỏi
73
+ num_return_sequences=1, # Tạo 1 câu hỏi duy nhất
74
+ device=0 if torch.cuda.is_available() else -1 # Sử dụng GPU nếu có, mặc định CPU
75
+ )
76
+
77
+ # Hàm tạo câu hỏi từ context và answer
78
+ def generate_question(context, answer):
79
+ # Định dạng input theo cách mô hình đã được tinh chỉnh
80
+ input_text = f"context: {context} answer: {answer}"
81
+ try:
82
+ result = pipe(input_text)[0]['generated_text']
83
+ return result
84
+ except Exception as e:
85
+ print(f"Lỗi khi tạo câu hỏi: {e}")
86
+ return None
87
+
88
+ # Thử nghiệm pipeline với một ví dụ
89
+ context = "The Vatican Apostolic Library is located in Vatican City."
90
+ answer = "Vatican City"
91
+ question = generate_question(context, answer)
92
+ if question:
93
+ print(f"Context: {context}")
94
+ print(f"Answer: {answer}")
95
+ print(f"Generated Question: {question}")
96
+
97
+ # (Tùy chọn) Kiểm tra với dữ liệu từ SQuAD
98
+ sample = train_ds[0] # Lấy mẫu đầu tiên từ tập dữ liệu
99
+ context_sample = sample['context']
100
+ answer_sample = sample['answers']['text'][0] if sample['answers']['text'] else "No answer"
101
+ question_sample = generate_question(context_sample, answer_sample)
102
+ if question_sample:
103
+ print(f"\nSample Context: {context_sample}")
104
+ print(f"Sample Answer: {answer_sample}")
105
+ print(f"Generated Question: {question_sample}")
106
+ # /Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/test.py