bnithichanquyt commited on
Commit
d87ccd6
·
verified ·
1 Parent(s): 1694d2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -28,10 +28,8 @@ def fix_bartpho_output(text: str) -> str:
28
 
29
 
30
  # HuggingFace Transformers
31
- from transformers import (
32
- AutoTokenizer,
33
- AutoModelForSeq2SeqLM
34
- )
35
  #page config
36
  st.set_page_config(
37
  page_title="ViSum - Vietnamese News Summarization",
@@ -97,22 +95,29 @@ MODEL_ID = "OrdinaryAI/visum-qlora-5epochs"
97
 
98
  @st.cache_resource
99
  def load_model(model_id):
100
-
101
- # Load tokenizer
102
- tokenizer = AutoTokenizer.from_pretrained(model_id)
103
-
104
- # Load model
105
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
106
-
107
- # Kiểm tra có GPU không
 
 
 
 
 
 
 
 
 
 
 
108
  device = "cuda" if torch.cuda.is_available() else "cpu"
109
-
110
- # Đưa model sang device tương ứng
111
  model = model.to(device)
112
-
113
- # Chuyển sang chế độ inference
114
  model.eval()
115
-
116
  return tokenizer, model
117
 
118
 
 
28
 
29
 
30
  # HuggingFace Transformers
31
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
32
+ from peft import PeftModel, PeftConfig
 
 
33
  #page config
34
  st.set_page_config(
35
  page_title="ViSum - Vietnamese News Summarization",
 
95
 
96
  @st.cache_resource
97
  def load_model(model_id):
98
+ # Đọc config PEFT để biết model gốc là gì
99
+ peft_config = PeftConfig.from_pretrained(model_id)
100
+
101
+ # Load model gốc (vinai/bartpho-syllable)
102
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
103
+ peft_config.base_model_name_or_path
104
+ )
105
+
106
+ # Gắn trọng số QLoRA vào
107
+ model = PeftModel.from_pretrained(base_model, model_id)
108
+
109
+ # Merge vào model gốc → inference nhanh hơn
110
+ model = model.merge_and_unload()
111
+
112
+ # Load tokenizer từ model gốc
113
+ tokenizer = AutoTokenizer.from_pretrained(
114
+ peft_config.base_model_name_or_path
115
+ )
116
+
117
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
118
  model = model.to(device)
 
 
119
  model.eval()
120
+
121
  return tokenizer, model
122
 
123