bnithichanquyt commited on
Commit
e732964
·
verified ·
1 Parent(s): 41647c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -372
app.py CHANGED
@@ -1,372 +1,372 @@
1
- #thư viện
2
- import time
3
- import torch
4
- import streamlit as st
5
-
6
- # HuggingFace Transformers
7
- from transformers import (
8
- AutoTokenizer,
9
- AutoModelForSeq2SeqLM
10
- )
11
- #page config
12
- st.set_page_config(
13
- page_title="ViSum - Vietnamese Summarization",
14
- page_icon="📰",
15
- layout="wide"
16
- )
17
- #custom css
18
- st.markdown("""
19
- <style>
20
-
21
- /* Font toàn app */
22
- html, body, [class*="css"] {
23
- font-family: 'Arial', sans-serif;
24
- }
25
-
26
- /* Nút chính */
27
- .stButton > button[kind="primary"] {
28
- background-color: #1a73e8;
29
- color: white;
30
- border: none;
31
- border-radius: 10px;
32
- padding: 0.6rem 1.5rem;
33
- font-size: 16px;
34
- font-weight: 600;
35
- }
36
-
37
- .stButton > button[kind="primary"]:hover {
38
- background-color: #1557b0;
39
- }
40
-
41
- /* Text area */
42
- .stTextArea textarea {
43
- border-radius: 10px;
44
- border: 1px solid #d0d0d0;
45
- line-height: 1.6;
46
- font-size: 15px;
47
- }
48
-
49
- /* Metric cards */
50
- [data-testid="metric-container"] {
51
- background-color: #f8f9fa;
52
- border: 1px solid #e0e0e0;
53
- padding: 15px;
54
- border-radius: 12px;
55
- }
56
-
57
- /* Responsive cho mobile */
58
- @media (max-width: 768px) {
59
- h1 {
60
- font-size: 1.8rem;
61
- }
62
-
63
- .stTextArea textarea {
64
- font-size: 14px;
65
- }
66
- }
67
-
68
- </style>
69
- """, unsafe_allow_html=True)
70
- #model config
71
- MODEL_ID = "VietAI/vit5-base-vietnews-summarization"
72
-
73
-
74
- # =============================================================================
75
- # LOAD MODEL
76
- #
77
- # @st.cache_resource:
78
- # Streamlit chỉ load model 1 lần duy nhất
79
- # Những lần sau dùng cache -> app nhanh hơn rất nhiều
80
- # =============================================================================
81
-
82
- @st.cache_resource
83
- def load_model(model_id):
84
-
85
- # Load tokenizer
86
- tokenizer = AutoTokenizer.from_pretrained(model_id)
87
-
88
- # Load model
89
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
90
-
91
- # Kiểm tra có GPU không
92
- device = "cuda" if torch.cuda.is_available() else "cpu"
93
-
94
- # Đưa model sang device tương ứng
95
- model = model.to(device)
96
-
97
- # Chuyển sang chế độ inference
98
- model.eval()
99
-
100
- return tokenizer, model
101
-
102
-
103
- # =============================================================================
104
- # HÀM TÓM TẮT
105
- #
106
- # Pipeline:
107
- # text
108
- # -> tokenize
109
- # -> model.generate()
110
- # -> decode
111
- # =============================================================================
112
-
113
- def summarize_text(
114
- text,
115
- tokenizer,
116
- model,
117
- max_length=150,
118
- min_length=50,
119
- num_beams=4
120
- ):
121
-
122
- # Lấy device hiện tại của model
123
- device = next(model.parameters()).device
124
-
125
- # Bắt đầu tính thời gian xử lý
126
- start_time = time.time()
127
-
128
- # =========================================================
129
- # TOKENIZE
130
- # Chuyển text -> tensor số
131
- # =========================================================
132
-
133
- inputs = tokenizer(
134
- text,
135
- return_tensors="pt",
136
- max_length=1024,
137
- truncation=True,
138
- padding=True
139
- ).to(device)
140
-
141
- # =========================================================
142
- # GENERATE SUMMARY
143
- # =========================================================
144
-
145
- with torch.no_grad():
146
-
147
- output_ids = model.generate(
148
- input_ids=inputs["input_ids"],
149
- attention_mask=inputs["attention_mask"],
150
-
151
- max_length=max_length,
152
- min_length=min_length,
153
-
154
- num_beams=num_beams,
155
-
156
- early_stopping=True,
157
-
158
- # Tránh lặp cụm từ
159
- no_repeat_ngram_size=3
160
- )
161
-
162
- # =========================================================
163
- # DECODE
164
- # Token IDs -> text
165
- # =========================================================
166
-
167
- summary = tokenizer.decode(
168
- output_ids[0],
169
- skip_special_tokens=True
170
- )
171
-
172
- # Tổng thời gian xử lý
173
- elapsed_time = round(time.time() - start_time, 2)
174
-
175
- return {
176
- "summary": summary,
177
- "time": elapsed_time
178
- }
179
-
180
-
181
- # =============================================================================
182
- # SIDEBAR
183
- # =============================================================================
184
-
185
- with st.sidebar:
186
-
187
- st.markdown("# 📰 ViSum")
188
-
189
- st.caption("Vietnamese Text Summarization")
190
-
191
- st.markdown("---")
192
-
193
- st.subheader("⚙️ Cài đặt")
194
-
195
- # Slider độ dài tối đa
196
- max_length = st.slider(
197
- "Độ dài tối đa",
198
- min_value=50,
199
- max_value=500,
200
- value=150,
201
- step=10
202
- )
203
-
204
- # Slider độ dài tối thiểu
205
- min_length = st.slider(
206
- "Độ dài tối thiểu",
207
- min_value=10,
208
- max_value=200,
209
- value=50,
210
- step=10
211
- )
212
-
213
- # Beam search
214
- num_beams = st.slider(
215
- "Beam Search",
216
- min_value=1,
217
- max_value=8,
218
- value=4,
219
- step=1
220
- )
221
-
222
- st.markdown("---")
223
-
224
- st.caption(f"Model: {MODEL_ID}")
225
-
226
- st.caption("Ordinary-AI-Engineer")
227
-
228
-
229
- # =============================================================================
230
- # MAIN UI
231
- # =============================================================================
232
-
233
- st.title("ViSum — Tóm tắt văn bản tiếng Việt")
234
-
235
- st.markdown("""
236
- Dán bài báo hoặc đoạn văn tiếng Việt vào ô bên dưới,
237
- sau đó nhấn **Tóm tắt** để AI tạo bản tóm tắt ngắn gọn.
238
- """)
239
-
240
-
241
- # =============================================================================
242
- # INPUT TEXT AREA
243
- # =============================================================================
244
-
245
- input_text = st.text_area(
246
- label="Văn bản gốc",
247
- placeholder="Nhập nội dung tại đây...",
248
- height=320
249
- )
250
-
251
-
252
- # =============================================================================
253
- # BUTTON
254
- # =============================================================================
255
-
256
- col1, col2, col3 = st.columns([1, 2, 1])
257
-
258
- with col2:
259
-
260
- summarize_button = st.button(
261
- "Tóm tắt",
262
- type="primary",
263
- use_container_width=True
264
- )
265
-
266
-
267
- # =============================================================================
268
- # XỬ LÝ KHI USER NHẤN NÚT
269
- # =============================================================================
270
-
271
- if summarize_button:
272
-
273
- # Xóa khoảng trắng thừa
274
- clean_text = input_text.strip()
275
-
276
- # =========================================================
277
- # VALIDATION
278
- # =========================================================
279
-
280
- if not clean_text:
281
-
282
- st.error("Vui lòng nhập văn bản!")
283
-
284
- elif len(clean_text) < 100:
285
-
286
- st.warning(
287
- "Văn bản quá ngắn! "
288
- "Kết quả tóm tắt có thể không chính xác."
289
- )
290
-
291
- else:
292
- #load model
293
- with st.spinner("Đang load model..."):
294
-
295
- tokenizer, model = load_model(MODEL_ID)
296
-
297
- # =====================================================
298
- # SUMMARIZE
299
- # =====================================================
300
-
301
- with st.spinner("Đang tóm tắt văn bản, xin hãy chờ trong giấy lát."):
302
-
303
- result = summarize_text(
304
- text=clean_text,
305
- tokenizer=tokenizer,
306
- model=model,
307
- max_length=max_length,
308
- min_length=min_length,
309
- num_beams=num_beams
310
- )
311
-
312
- summary = result["summary"]
313
- elapsed = result["time"]
314
-
315
- # =====================================================
316
- # OUTPUT
317
- # =====================================================
318
-
319
- st.success("Tóm tắt hoàn thành!")
320
-
321
- st.text_area(
322
- label="Kết quả tóm tắt: ",
323
- value=summary,
324
- height=220
325
- )
326
-
327
- # =====================================================
328
- # METRICS
329
- # =====================================================
330
-
331
- original_words = len(clean_text.split())
332
-
333
- summary_words = len(summary.split())
334
-
335
- reduction_percent = round(
336
- (1 - summary_words / original_words) * 100,
337
- 1
338
- )
339
-
340
- m1, m2, m3, m4 = st.columns(4)
341
-
342
- m1.metric(
343
- "Thời gian",
344
- f"{elapsed}s"
345
- )
346
-
347
- m2.metric(
348
- "Từ gốc",
349
- original_words
350
- )
351
-
352
- m3.metric(
353
- "Từ tóm tắt",
354
- summary_words
355
- )
356
-
357
- m4.metric(
358
- "Rút gọn",
359
- f"{reduction_percent}%"
360
- )
361
-
362
-
363
- # =============================================================================
364
- # FOOTER
365
- # =============================================================================
366
-
367
- st.markdown("---")
368
-
369
- st.caption(
370
- "ViSum • Vietnamese Summarization System • "
371
- "Powered by Hugging Face Transformers"
372
- )
 
1
+ #thư viện
2
+ import time
3
+ import torch
4
+ import streamlit as st
5
+
6
+ # HuggingFace Transformers
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoModelForSeq2SeqLM
10
+ )
11
+ #page config
12
+ st.set_page_config(
13
+ page_title="ViSum - Vietnamese Summarization",
14
+ page_icon="📰",
15
+ layout="wide"
16
+ )
17
+ #custom css
18
+ st.markdown("""
19
+ <style>
20
+
21
+ /* Font toàn app */
22
+ html, body, [class*="css"] {
23
+ font-family: 'Arial', sans-serif;
24
+ }
25
+
26
+ /* Nút chính */
27
+ .stButton > button[kind="primary"] {
28
+ background-color: #1a73e8;
29
+ color: white;
30
+ border: none;
31
+ border-radius: 10px;
32
+ padding: 0.6rem 1.5rem;
33
+ font-size: 16px;
34
+ font-weight: 600;
35
+ }
36
+
37
+ .stButton > button[kind="primary"]:hover {
38
+ background-color: #1557b0;
39
+ }
40
+
41
+ /* Text area */
42
+ .stTextArea textarea {
43
+ border-radius: 10px;
44
+ border: 1px solid #d0d0d0;
45
+ line-height: 1.6;
46
+ font-size: 15px;
47
+ }
48
+
49
+ /* Metric cards */
50
+ [data-testid="metric-container"] {
51
+ background-color: #f8f9fa;
52
+ border: 1px solid #e0e0e0;
53
+ padding: 15px;
54
+ border-radius: 12px;
55
+ }
56
+
57
+ /* Responsive cho mobile */
58
+ @media (max-width: 768px) {
59
+ h1 {
60
+ font-size: 1.8rem;
61
+ }
62
+
63
+ .stTextArea textarea {
64
+ font-size: 14px;
65
+ }
66
+ }
67
+
68
+ </style>
69
+ """, unsafe_allow_html=True)
70
+ #model config
71
+ MODEL_ID = "OrdinaryAI/visum-qlora-5epochs"
72
+
73
+
74
+ # =============================================================================
75
+ # LOAD MODEL
76
+ #
77
+ # @st.cache_resource:
78
+ # Streamlit chỉ load model 1 lần duy nhất
79
+ # Những lần sau dùng cache -> app nhanh hơn rất nhiều
80
+ # =============================================================================
81
+
82
+ @st.cache_resource
83
+ def load_model(model_id):
84
+
85
+ # Load tokenizer
86
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
87
+
88
+ # Load model
89
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
90
+
91
+ # Kiểm tra có GPU không
92
+ device = "cuda" if torch.cuda.is_available() else "cpu"
93
+
94
+ # Đưa model sang device tương ứng
95
+ model = model.to(device)
96
+
97
+ # Chuyển sang chế độ inference
98
+ model.eval()
99
+
100
+ return tokenizer, model
101
+
102
+
103
+ # =============================================================================
104
+ # HÀM TÓM TẮT
105
+ #
106
+ # Pipeline:
107
+ # text
108
+ # -> tokenize
109
+ # -> model.generate()
110
+ # -> decode
111
+ # =============================================================================
112
+
113
+ def summarize_text(
114
+ text,
115
+ tokenizer,
116
+ model,
117
+ max_length=150,
118
+ min_length=50,
119
+ num_beams=4
120
+ ):
121
+
122
+ # Lấy device hiện tại của model
123
+ device = next(model.parameters()).device
124
+
125
+ # Bắt đầu tính thời gian xử lý
126
+ start_time = time.time()
127
+
128
+ # =========================================================
129
+ # TOKENIZE
130
+ # Chuyển text -> tensor số
131
+ # =========================================================
132
+
133
+ inputs = tokenizer(
134
+ text,
135
+ return_tensors="pt",
136
+ max_length=1024,
137
+ truncation=True,
138
+ padding=True
139
+ ).to(device)
140
+
141
+ # =========================================================
142
+ # GENERATE SUMMARY
143
+ # =========================================================
144
+
145
+ with torch.no_grad():
146
+
147
+ output_ids = model.generate(
148
+ input_ids=inputs["input_ids"],
149
+ attention_mask=inputs["attention_mask"],
150
+
151
+ max_length=max_length,
152
+ min_length=min_length,
153
+
154
+ num_beams=num_beams,
155
+
156
+ early_stopping=True,
157
+
158
+ # Tránh lặp cụm từ
159
+ no_repeat_ngram_size=3
160
+ )
161
+
162
+ # =========================================================
163
+ # DECODE
164
+ # Token IDs -> text
165
+ # =========================================================
166
+
167
+ summary = tokenizer.decode(
168
+ output_ids[0],
169
+ skip_special_tokens=True
170
+ )
171
+
172
+ # Tổng thời gian xử lý
173
+ elapsed_time = round(time.time() - start_time, 2)
174
+
175
+ return {
176
+ "summary": summary,
177
+ "time": elapsed_time
178
+ }
179
+
180
+
181
+ # =============================================================================
182
+ # SIDEBAR
183
+ # =============================================================================
184
+
185
+ with st.sidebar:
186
+
187
+ st.markdown("# 📰 ViSum")
188
+
189
+ st.caption("Vietnamese Text Summarization")
190
+
191
+ st.markdown("---")
192
+
193
+ st.subheader("⚙️ Cài đặt")
194
+
195
+ # Slider độ dài tối đa
196
+ max_length = st.slider(
197
+ "Độ dài tối đa",
198
+ min_value=50,
199
+ max_value=500,
200
+ value=150,
201
+ step=10
202
+ )
203
+
204
+ # Slider độ dài tối thiểu
205
+ min_length = st.slider(
206
+ "Độ dài tối thiểu",
207
+ min_value=10,
208
+ max_value=200,
209
+ value=50,
210
+ step=10
211
+ )
212
+
213
+ # Beam search
214
+ num_beams = st.slider(
215
+ "Beam Search",
216
+ min_value=1,
217
+ max_value=8,
218
+ value=4,
219
+ step=1
220
+ )
221
+
222
+ st.markdown("---")
223
+
224
+ st.caption(f"Model: {MODEL_ID}")
225
+
226
+ st.caption("Ordinary-AI-Engineer")
227
+
228
+
229
+ # =============================================================================
230
+ # MAIN UI
231
+ # =============================================================================
232
+
233
+ st.title("ViSum — Tóm tắt văn bản tiếng Việt")
234
+
235
+ st.markdown("""
236
+ Dán bài báo hoặc đoạn văn tiếng Việt vào ô bên dưới,
237
+ sau đó nhấn **Tóm tắt** để AI tạo bản tóm tắt ngắn gọn.
238
+ """)
239
+
240
+
241
+ # =============================================================================
242
+ # INPUT TEXT AREA
243
+ # =============================================================================
244
+
245
+ input_text = st.text_area(
246
+ label="Văn bản gốc",
247
+ placeholder="Nhập nội dung tại đây...",
248
+ height=320
249
+ )
250
+
251
+
252
+ # =============================================================================
253
+ # BUTTON
254
+ # =============================================================================
255
+
256
+ col1, col2, col3 = st.columns([1, 2, 1])
257
+
258
+ with col2:
259
+
260
+ summarize_button = st.button(
261
+ "Tóm tắt",
262
+ type="primary",
263
+ use_container_width=True
264
+ )
265
+
266
+
267
+ # =============================================================================
268
+ # XỬ LÝ KHI USER NHẤN NÚT
269
+ # =============================================================================
270
+
271
+ if summarize_button:
272
+
273
+ # Xóa khoảng trắng thừa
274
+ clean_text = input_text.strip()
275
+
276
+ # =========================================================
277
+ # VALIDATION
278
+ # =========================================================
279
+
280
+ if not clean_text:
281
+
282
+ st.error("Vui lòng nhập văn bản!")
283
+
284
+ elif len(clean_text) < 100:
285
+
286
+ st.warning(
287
+ "Văn bản quá ngắn! "
288
+ "Kết quả tóm tắt có thể không chính xác."
289
+ )
290
+
291
+ else:
292
+ #load model
293
+ with st.spinner("Đang load model..."):
294
+
295
+ tokenizer, model = load_model(MODEL_ID)
296
+
297
+ # =====================================================
298
+ # SUMMARIZE
299
+ # =====================================================
300
+
301
+ with st.spinner("Đang tóm tắt văn bản, xin hãy chờ trong giấy lát."):
302
+
303
+ result = summarize_text(
304
+ text=clean_text,
305
+ tokenizer=tokenizer,
306
+ model=model,
307
+ max_length=max_length,
308
+ min_length=min_length,
309
+ num_beams=num_beams
310
+ )
311
+
312
+ summary = result["summary"]
313
+ elapsed = result["time"]
314
+
315
+ # =====================================================
316
+ # OUTPUT
317
+ # =====================================================
318
+
319
+ st.success("Tóm tắt hoàn thành!")
320
+
321
+ st.text_area(
322
+ label="Kết quả tóm tắt: ",
323
+ value=summary,
324
+ height=220
325
+ )
326
+
327
+ # =====================================================
328
+ # METRICS
329
+ # =====================================================
330
+
331
+ original_words = len(clean_text.split())
332
+
333
+ summary_words = len(summary.split())
334
+
335
+ reduction_percent = round(
336
+ (1 - summary_words / original_words) * 100,
337
+ 1
338
+ )
339
+
340
+ m1, m2, m3, m4 = st.columns(4)
341
+
342
+ m1.metric(
343
+ "Thời gian",
344
+ f"{elapsed}s"
345
+ )
346
+
347
+ m2.metric(
348
+ "Từ gốc",
349
+ original_words
350
+ )
351
+
352
+ m3.metric(
353
+ "Từ tóm tắt",
354
+ summary_words
355
+ )
356
+
357
+ m4.metric(
358
+ "Rút gọn",
359
+ f"{reduction_percent}%"
360
+ )
361
+
362
+
363
+ # =============================================================================
364
+ # FOOTER
365
+ # =============================================================================
366
+
367
+ st.markdown("---")
368
+
369
+ st.caption(
370
+ "ViSum • Vietnamese Summarization System • "
371
+ "Powered by Hugging Face Transformers"
372
+ )