Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -1,11 +1,11 @@
1
 
2
  import gradio as gr
3
  import spacy
4
- from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline
5
  import torch
6
  import time
7
  import re
8
- import os # Đảm bảo bạn đã import os
9
 
10
  # Tải mô hình spaCy
11
  if not spacy.util.is_package("en_core_web_md"):
@@ -15,7 +15,8 @@ nlp = spacy.load("en_core_web_md")
15
  print("✅ Đã tải/nạp mô hình spaCy.")
16
  MODEL_PATHS = {
17
  "prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
18
- "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg"
 
19
  }
20
 
21
  def load_pipeline(model_path):
@@ -30,17 +31,31 @@ def load_pipeline(model_path):
30
  device=0 if torch.cuda.is_available() else -1
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  pipeline_cache = {}
34
 
35
  def get_pipeline(model_name):
36
  model_path = MODEL_PATHS[model_name]
37
  if model_name not in pipeline_cache:
38
- pipeline_cache[model_name] = load_pipeline(model_path)
 
 
 
39
  return pipeline_cache[model_name]
40
 
41
  # Tự viết hàm capitalize thông minh
42
 
43
-
44
  def smart_capitalize(text):
45
  # Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
46
  text = text.strip()
@@ -54,7 +69,11 @@ def smart_capitalize(text):
54
  def generate_question(context, answer, model_name):
55
  pipe = get_pipeline(model_name)
56
  tokenizer = pipe.tokenizer
57
- prompt = f"context: {context} answer: {answer}"
 
 
 
 
58
 
59
  # Cắt prompt nếu vượt quá giới hạn token
60
  encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
1
 
2
  import gradio as gr
3
  import spacy
4
+ from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration, pipeline, T5Tokenizer, T5ForConditionalGeneration
5
  import torch
6
  import time
7
  import re
8
+ import os
9
 
10
  # Tải mô hình spaCy
11
  if not spacy.util.is_package("en_core_web_md"):
 
15
  print("✅ Đã tải/nạp mô hình spaCy.")
16
  MODEL_PATHS = {
17
  "prophetnet2": "ManB2207540/prophetnet_SQuAD_1.1-2epoch_break",
18
+ "prophetnet tieu chuan": "microsoft/prophetnet-large-uncased-squad-qg",
19
+ "t5-small-finetuned": "tbtminh/t5-small-qg-finetuned"
20
  }
21
 
22
  def load_pipeline(model_path):
 
31
  device=0 if torch.cuda.is_available() else -1
32
  )
33
 
34
def load_t5_pipeline(model_path):
    """Build a ``text2text-generation`` pipeline around a T5 checkpoint.

    Args:
        model_path: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        The object produced by ``pipeline(...)``, configured with
        ``max_length=256`` and ``num_return_sequences=1``.
    """
    # Tokenizer is loaded before the model weights.
    t5_tokenizer = T5Tokenizer.from_pretrained(model_path)
    t5_model = T5ForConditionalGeneration.from_pretrained(model_path)

    # Device index 0 when CUDA is available, otherwise -1.
    if torch.cuda.is_available():
        device_index = 0
    else:
        device_index = -1

    pipeline_options = {
        "model": t5_model,
        "tokenizer": t5_tokenizer,
        "max_length": 256,
        "num_return_sequences": 1,
        "device": device_index,
    }
    return pipeline("text2text-generation", **pipeline_options)
45
+
46
  pipeline_cache = {}
47
 
48
  def get_pipeline(model_name):
49
  model_path = MODEL_PATHS[model_name]
50
  if model_name not in pipeline_cache:
51
+ if model_name == "t5-small-finetuned":
52
+ pipeline_cache[model_name] = load_t5_pipeline(model_path)
53
+ else:
54
+ pipeline_cache[model_name] = load_pipeline(model_path)
55
  return pipeline_cache[model_name]
56
 
57
  # Tự viết hàm capitalize thông minh
58
 
 
59
  def smart_capitalize(text):
60
  # Giữ nguyên cách viết hoa phần còn lại, chỉ viết hoa chữ đầu nếu cần
61
  text = text.strip()
 
69
  def generate_question(context, answer, model_name):
70
  pipe = get_pipeline(model_name)
71
  tokenizer = pipe.tokenizer
72
+
73
+ if model_name == "t5-small-finetuned":
74
+ prompt = f"generate question: context: {context} answer: {answer}"
75
+ else:
76
+ prompt = f"context: {context} answer: {answer}"
77
 
78
  # Cắt prompt nếu vượt quá giới hạn token
79
  encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)