ManB2207540 committed
Commit e9c022b · Parent: 6f22fac

modify app.py for load models pipeline function

Files changed (1): app.py (+50 -9)
app.py CHANGED
@@ -2,7 +2,7 @@
 import gradio as gr
 import spacy
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoConfig
-from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration
+from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, BartTokenizer, BartForConditionalGeneration
 import torch
 import time
 import re
@@ -16,8 +16,7 @@ nlp = spacy.load("en_core_web_md")
 
 # Model paths
 MODEL_PATHS = {
-    "prophetnet-finetuned": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
-    "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg",
+    "prophetnet-large-uncased-finetuned": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
     "bart-finetuned": "mghan3624/bart_qg_finetune_squad",
     "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
 }
@@ -39,14 +38,48 @@ def load_t5_pipeline(model_path):
         print(f"Failed to load T5 pipeline for {model_path}: {e}")
         return None
 
-# Generic model loading function
-def load_pipeline(model_path):
+# ProphetNet model loading function
+def load_prophetnet_pipeline(model_path):
+    try:
+        tokenizer = ProphetNetTokenizer.from_pretrained(model_path)
+        model = ProphetNetForConditionalGeneration.from_pretrained(model_path)
+        return pipeline(
+            "text2text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_length=256,
+            num_return_sequences=1,
+            device=0 if torch.cuda.is_available() else -1
+        )
+    except Exception as e:
+        print(f"Failed to load ProphetNet pipeline for {model_path}: {e}")
+        return None
+
+# Bart model loading function
+def load_bart_pipeline(model_path):
     try:
         config = AutoConfig.from_pretrained(model_path)
         if getattr(config, "early_stopping", None) is None:
             config.early_stopping = False
+        tokenizer = BartTokenizer.from_pretrained(model_path)
+        model = BartForConditionalGeneration.from_pretrained(model_path, config=config)
+        return pipeline(
+            "text2text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_length=256,
+            num_return_sequences=1,
+            device=0 if torch.cuda.is_available() else -1
+        )
+    except Exception as e:
+        print(f"Failed to load Bart pipeline for {model_path}: {e}")
+        return None
+
+# Generic model loading function
+def load_pipeline(model_path):
+    try:
         tokenizer = AutoTokenizer.from_pretrained(model_path)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
         return pipeline(
             "text2text-generation",
             model=model,
@@ -67,6 +100,10 @@ def get_pipeline(model_name):
     if model_name not in pipeline_cache:
         if model_name == "t5-small-finetuned":
             pipeline_cache[model_name] = load_t5_pipeline(model_path)
+        elif model_name == "prophetnet-large-uncased-finetuned":
+            pipeline_cache[model_name] = load_prophetnet_pipeline(model_path)
+        elif model_name == "bart-finetuned":
+            pipeline_cache[model_name] = load_bart_pipeline(model_path)
         else:
             pipeline_cache[model_name] = load_pipeline(model_path)
     return pipeline_cache[model_name]
@@ -89,6 +126,10 @@ def generate_question(context, answer, model_name):
     tokenizer = pipe.tokenizer
     if model_name == "t5-small-finetuned":
         prompt = f"generate question: context: {context} answer: {answer}"
+    elif model_name == "prophetnet-large-uncased-finetuned":
+        prompt = f"context: {context} answer: {answer}"
+    elif model_name == "bart-finetuned":
+        prompt = f"context: {context} answer: {answer}"
     else:
         prompt = f"context: {context} answer: {answer}"
     print(f"Prompt: {prompt}")  # Print the prompt for inspection
@@ -97,7 +138,7 @@ def generate_question(context, answer, model_name):
     encoded = tokenizer(prompt, return_tensors="pt", truncation=False, max_length=512)
     input_ids = encoded["input_ids"]
     if input_ids.size(1) > 512:
-        return "❌ Context too long (over 512 tokens). Please enter a shorter context."
+        return "❌ Text too long. Please enter shorter text."  # (over 512 tokens)
 
     # Proceed with tokenization (with truncation if needed)
     encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
@@ -173,11 +214,11 @@ def analyze_context(context, num_questions):
     )
 
 with gr.Blocks() as demo:
-    gr.Markdown("## Question generation system from Context using Encoder-Decoder models + spaCy NER")
+    gr.Markdown("## Question generation system")
 
     with gr.Row():
         with gr.Column(scale=4):
-            context_input = gr.Textbox(label="Enter Context", lines=15, placeholder="Enter a passage of text...")
+            context_input = gr.Textbox(label="Enter text", lines=15, placeholder="Enter a passage of text...")
             elapsed_time_md = gr.Markdown(visible=False)
         with gr.Column(scale=1):
             model_choice = gr.Dropdown(
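
For reference, a minimal usage sketch of the loaders added in this commit, assuming app.py is importable as `app` and the model weights download successfully; the sample context and answer strings are illustrative only, not part of the commit:

import app

# get_pipeline() resolves the model path from MODEL_PATHS, dispatches to the
# matching loader (load_prophetnet_pipeline / load_bart_pipeline /
# load_t5_pipeline / load_pipeline), and caches the result in pipeline_cache
# so each model is loaded only once per process.
pipe = app.get_pipeline("prophetnet-large-uncased-finetuned")

# generate_question() builds the model-specific prompt, enforces the
# 512-token input limit, and returns the generated question text.
question = app.generate_question(
    context="The Eiffel Tower was completed in 1889 in Paris.",  # illustrative sample
    answer="1889",
    model_name="prophetnet-large-uncased-finetuned",
)
print(question)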