ManB2207540 commited on
Commit
a291170
·
1 Parent(s): a6ed7ac

modify app.py for T5 and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +86 -34
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,41 +1,74 @@
1
 
2
  import gradio as gr
3
  import spacy
4
- from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
 
5
  import torch
6
  import time
7
  import re
8
- import os # Đảm bảo bạn đã import os
9
 
10
  # Tải mô hình spaCy
11
  if not spacy.util.is_package("en_core_web_md"):
12
  print("Đang tải mô hình spaCy 'en_core_web_md'...")
13
- spacy.cli.download("en_core_web_md") # <--- Lỗi xảy ra ở đây
14
  nlp = spacy.load("en_core_web_md")
15
- print("✅ Đã tải/nạp mô hình spaCy.")
 
16
  MODEL_PATHS = {
17
- "prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
18
- "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg"
 
 
19
  }
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def load_pipeline(model_path):
22
- tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
23
- model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
24
- return pipeline(
25
- "text2text-generation",
26
- model=model,
27
- tokenizer=tokenizer,
28
- max_length=256,
29
- num_return_sequences=1,
30
- device=0 if torch.cuda.is_available() else -1
31
- )
 
 
 
 
 
 
 
32
 
 
33
  pipeline_cache = {}
34
 
35
  def get_pipeline(model_name):
36
  model_path = MODEL_PATHS[model_name]
37
  if model_name not in pipeline_cache:
38
- pipeline_cache[model_name] = load_pipeline(model_path)
 
 
 
39
  return pipeline_cache[model_name]
40
 
41
  # Tự viết hàm capitalize thông minh
@@ -54,9 +87,19 @@ def smart_capitalize(text):
54
  def generate_question(context, answer, model_name):
55
  pipe = get_pipeline(model_name)
56
  tokenizer = pipe.tokenizer
57
- prompt = f"context: {context} answer: {answer}"
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Cắt prompt nếu vượt quá giới hạn token
60
  encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
61
  input_ids = encoded["input_ids"]
62
  attention_mask = encoded["attention_mask"]
@@ -87,23 +130,28 @@ def generate_qa_list(context, num_questions, model_choice):
87
  entities = list(set([ent.text for ent in doc.ents]))
88
  entities = [e for e in entities if len(e.strip().split()) <= 10]
89
 
 
90
  if not entities:
91
- return gr.update(visible=True), ["❌ Không tìm thấy thực thể nào để sinh câu hỏi."]
92
 
 
93
  count = min(num_questions, len(entities))
94
  qa_list = []
95
 
96
  for i in range(count):
97
  answer = entities[i]
98
  question = generate_question(context, answer, model_choice)
 
 
 
99
  answer = smart_capitalize(entities[i])
100
  qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
101
  qa_list.append(qa)
102
 
103
- return gr.update(visible=True), qa_list
104
 
105
  # Tách phần phân tích context và cập nhật slider
106
- def analyze_context(context):
107
  doc = nlp(context)
108
  entities = list(set([ent.text for ent in doc.ents]))
109
  entities = [e for e in entities if len(e.strip().split()) <= 10]
@@ -119,13 +167,13 @@ def analyze_context(context):
119
  else:
120
  return (
121
  gr.update(visible=False),
122
- gr.update(visible=True, maximum=entity_count, value=min(3, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
123
  gr.update(visible=True),
124
  gr.update(visible=True)
125
  )
126
 
127
  with gr.Blocks() as demo:
128
- gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng Seq2Seq Transformer + spaCy NER")
129
 
130
  with gr.Row():
131
  with gr.Column(scale=4):
@@ -135,9 +183,9 @@ with gr.Blocks() as demo:
135
  model_choice = gr.Dropdown(
136
  label="Chọn mô hình",
137
  choices=list(MODEL_PATHS.keys()),
138
- value="prophetnet2"
139
  )
140
- num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=5, value=3, step=1, visible=False)
141
  generate_btn = gr.Button("Sinh câu hỏi", visible=False)
142
 
143
  # Thông báo đang xử lý hoặc không tìm thấy
@@ -145,16 +193,19 @@ with gr.Blocks() as demo:
145
 
146
  # Kết quả hiển thị tại đây
147
  with gr.Column(visible=False) as output_container:
148
- result_md_list = [gr.Markdown(visible=False) for _ in range(5)]
149
 
150
  # Xử lý khi bấm nút sinh câu hỏi
151
  def run_generation(context, num_questions, model_choice):
152
  start_time = time.time()
153
- visible_container, qa_list = generate_qa_list(context, num_questions, model_choice)
154
- status_hide = gr.update(visible=False)
155
- updates = []
 
 
156
 
157
- for i in range(5):
 
158
  if i < len(qa_list):
159
  updates.append(gr.update(value=qa_list[i], visible=True))
160
  else:
@@ -164,12 +215,12 @@ with gr.Blocks() as demo:
164
  elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
165
  elapsed_md = gr.update(value=elapsed_msg, visible=True)
166
 
167
- return [status_hide, visible_container, elapsed_md] + updates
168
 
169
  # Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
170
  context_input.change(
171
  fn=analyze_context,
172
- inputs=[context_input],
173
  outputs=[status_message, num_input, generate_btn, elapsed_time_md]
174
  )
175
 
@@ -186,6 +237,7 @@ with gr.Blocks() as demo:
186
  outputs=[status_message, output_container, elapsed_time_md] + result_md_list
187
  )
188
 
189
- demo.launch(share=True)
 
190
 
191
  # #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py
 
1
 
2
  import gradio as gr
3
  import spacy
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig
5
+ from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration
6
  import torch
7
  import time
8
  import re
9
+ import os
10
 
11
  # Tải mô hình spaCy
12
  if not spacy.util.is_package("en_core_web_md"):
13
  print("Đang tải mô hình spaCy 'en_core_web_md'...")
14
+ spacy.cli.download("en_core_web_md")
15
  nlp = spacy.load("en_core_web_md")
16
+
17
+ # Đường dẫn mô hình
18
  MODEL_PATHS = {
19
+ "prophetnet-finetuned": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
20
+ "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg",
21
+ "bart-finetuned": "mghan3624/bart_qg_finetune_squad",
22
+ "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
23
  }
24
 
25
+ # Hàm tải mô hình T5
26
+ def load_t5_pipeline(model_path):
27
+ try:
28
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
29
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
30
+ return pipeline(
31
+ "text2text-generation",
32
+ model=model,
33
+ tokenizer=tokenizer,
34
+ max_length=256,
35
+ num_return_sequences=1,
36
+ device=0 if torch.cuda.is_available() else -1
37
+ )
38
+ except Exception as e:
39
+ print(f"Failed to load T5 pipeline for {model_path}: {e}")
40
+ return None
41
+
42
+ # Hàm tải mô hình chung
43
  def load_pipeline(model_path):
44
+ try:
45
+ config = AutoConfig.from_pretrained(model_path)
46
+ if getattr(config, "early_stopping", None) is None:
47
+ config.early_stopping = False
48
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
49
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path, config=config)
50
+ return pipeline(
51
+ "text2text-generation",
52
+ model=model,
53
+ tokenizer=tokenizer,
54
+ max_length=256,
55
+ num_return_sequences=1,
56
+ device=0 if torch.cuda.is_available() else -1
57
+ )
58
+ except Exception as e:
59
+ print(f"Failed to load pipeline for {model_path}: {e}")
60
+ return None
61
 
62
+ # Cache pipeline
63
  pipeline_cache = {}
64
 
65
  def get_pipeline(model_name):
66
  model_path = MODEL_PATHS[model_name]
67
  if model_name not in pipeline_cache:
68
+ if model_name == "t5-small-finetuned":
69
+ pipeline_cache[model_name] = load_t5_pipeline(model_path)
70
+ else:
71
+ pipeline_cache[model_name] = load_pipeline(model_path)
72
  return pipeline_cache[model_name]
73
 
74
  # Tự viết hàm capitalize thông minh
 
87
  def generate_question(context, answer, model_name):
88
  pipe = get_pipeline(model_name)
89
  tokenizer = pipe.tokenizer
90
+ if model_name == "t5-small-finetuned":
91
+ prompt = f"generate question: context: {context} answer: {answer}"
92
+ else:
93
+ prompt = f"context: {context} answer: {answer}"
94
+ print(f"Prompt: {prompt}") # In ra prompt để kiểm tra
95
+
96
+ # Kiểm tra độ dài của prompt
97
+ encoded = tokenizer(prompt, return_tensors="pt", truncation=False, max_length=512)
98
+ input_ids = encoded["input_ids"]
99
+ if input_ids.size(1) > 512:
100
+ return "❌ Context quá dài (hơn 512 token). Xin nhập vào context ngắn hơn."
101
 
102
+ # Proceed with tokenization (with truncation if needed)
103
  encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
104
  input_ids = encoded["input_ids"]
105
  attention_mask = encoded["attention_mask"]
 
130
  entities = list(set([ent.text for ent in doc.ents]))
131
  entities = [e for e in entities if len(e.strip().split()) <= 10]
132
 
133
+ # Nếu không tìm thấy thực thể, trả về thông báo lỗi trong status_message
134
  if not entities:
135
+ return gr.update(visible=True, value="❌ Không tìm thấy thực thể nào để sinh câu hỏi."), []
136
 
137
+ # Đảm bảo số câu hỏi không vượt quá số thực thể
138
  count = min(num_questions, len(entities))
139
  qa_list = []
140
 
141
  for i in range(count):
142
  answer = entities[i]
143
  question = generate_question(context, answer, model_choice)
144
+ # Nếu có lỗi (như context quá dài), trả về thông báo lỗi trong status_message
145
+ if question.startswith("❌") or question.startswith("Lỗi"):
146
+ return gr.update(visible=True, value=question), []
147
  answer = smart_capitalize(entities[i])
148
  qa = f"**{question}**\n<details><summary>Hiện câu trả lời</summary><p>{answer}</p></details>"
149
  qa_list.append(qa)
150
 
151
+ return gr.update(visible=False), qa_list
152
 
153
  # Tách phần phân tích context và cập nhật slider
154
+ def analyze_context(context, num_questions):
155
  doc = nlp(context)
156
  entities = list(set([ent.text for ent in doc.ents]))
157
  entities = [e for e in entities if len(e.strip().split()) <= 10]
 
167
  else:
168
  return (
169
  gr.update(visible=False),
170
+ gr.update(visible=True, maximum=entity_count, value=min(num_questions, entity_count), label=f"Số câu hỏi (Tối đa: {entity_count})"),
171
  gr.update(visible=True),
172
  gr.update(visible=True)
173
  )
174
 
175
  with gr.Blocks() as demo:
176
+ gr.Markdown("## Hệ thống sinh câu hỏi từ Context bằng ProphetNet + spaCy NER")
177
 
178
  with gr.Row():
179
  with gr.Column(scale=4):
 
183
  model_choice = gr.Dropdown(
184
  label="Chọn mô hình",
185
  choices=list(MODEL_PATHS.keys()),
186
+ value="t5-small-finetuned"
187
  )
188
+ num_input = gr.Slider(label="Số câu hỏi", minimum=1, maximum=20, value=3, step=1, visible=False)
189
  generate_btn = gr.Button("Sinh câu hỏi", visible=False)
190
 
191
  # Thông báo đang xử lý hoặc không tìm thấy
 
193
 
194
  # Kết quả hiển thị tại đây
195
  with gr.Column(visible=False) as output_container:
196
+ result_md_list = [gr.Markdown(visible=False) for _ in range(20)]
197
 
198
  # Xử lý khi bấm nút sinh câu hỏi
199
  def run_generation(context, num_questions, model_choice):
200
  start_time = time.time()
201
+ status_message, qa_list = generate_qa_list(context, num_questions, model_choice)
202
+
203
+ # Nếu có lỗi (status_message visible), trả về ngay lập tức
204
+ if status_message["visible"]:
205
+ return [status_message, gr.update(visible=False), gr.update(visible=False)] + [gr.update(visible=False) for _ in range(20)]
206
 
207
+ updates = []
208
+ for i in range(20):
209
  if i < len(qa_list):
210
  updates.append(gr.update(value=qa_list[i], visible=True))
211
  else:
 
215
  elapsed_msg = f"⏱️ Thời gian xử lý: {elapsed:.2f} giây"
216
  elapsed_md = gr.update(value=elapsed_msg, visible=True)
217
 
218
+ return [gr.update(visible=False), gr.update(visible=True), elapsed_md] + updates
219
 
220
  # Khi người dùng thay đổi context, tự động phân tích thực thể và cập nhật slider
221
  context_input.change(
222
  fn=analyze_context,
223
+ inputs=[context_input, num_input], # Thêm num_input vào inputs
224
  outputs=[status_message, num_input, generate_btn, elapsed_time_md]
225
  )
226
 
 
237
  outputs=[status_message, output_container, elapsed_time_md] + result_md_list
238
  )
239
 
240
+ demo.launch()
241
+
242
 
243
  # #/Users/trantieuman/anaconda3/bin/python /Users/trantieuman/Documents/NLP/project/demo.py
requirements.txt CHANGED
@@ -70,3 +70,4 @@ uvicorn==0.35.0
70
  websockets==15.0.1
71
  xxhash==3.5.0
72
  yarl==1.20.1
 
 
70
  websockets==15.0.1
71
  xxhash==3.5.0
72
  yarl==1.20.1
73
+ sentencepiece==0.2.0