Visum / app.py
bnithichanquyt's picture
Update app.py
d87ccd6 verified
#thư viện
import time
import torch
import streamlit as st
import re
def fix_bartpho_output(text: str) -> str:
"""
BARTpho syllable hay bị dính từ kiểu:
'Coquan' → 'Cơ quan', 'đốitượng' → 'đối tượng'
Hàm này thêm dấu cách trước chữ hoa giữa câu
và fix một số pattern hay gặp.
"""
# Thêm space trước chữ hoa nằm giữa từ thường
# Ví dụ: "CơQuan" → "Cơ Quan"
text = re.sub(r'([a-zđàáâãèéêìíòóôõùúýăắặấầẩẫậắằẳẵặ])'
r'([A-ZĐÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝĂẮẶẤẦẨẪẬẮẰẲẴẶ])',
r'\1 \2', text)
# Fix dấu câu dính vào chữ: "vong.Cơ" → "vong. Cơ"
text = re.sub(r'([.!?,;:])([^\s])', r'\1 \2', text)
# Xóa khoảng trắng thừa
text = re.sub(r' +', ' ', text).strip()
return text
# HuggingFace Transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig
#page config
st.set_page_config(
page_title="ViSum - Vietnamese News Summarization",
page_icon="🇻🇳",
layout="wide"
)
#custom css
st.markdown("""
<style>
/* Font toàn app */
html, body, [class*="css"] {
font-family: 'Arial', sans-serif;
}
/* Nút chính */
.stButton > button[kind="primary"] {
background-color: #1a73e8;
color: white;
border: none;
border-radius: 10px;
padding: 0.6rem 1.5rem;
font-size: 16px;
font-weight: 600;
}
.stButton > button[kind="primary"]:hover {
background-color: #1557b0;
}
/* Text area */
.stTextArea textarea {
border-radius: 10px;
border: 1px solid #d0d0d0;
line-height: 1.6;
font-size: 15px;
}
/* Metric cards */
[data-testid="metric-container"] {
background-color: #f8f9fa;
border: 1px solid #e0e0e0;
padding: 15px;
border-radius: 12px;
}
/* Responsive cho mobile */
@media (max-width: 768px) {
h1 {
font-size: 1.8rem;
}
.stTextArea textarea {
font-size: 14px;
}
}
</style>
""", unsafe_allow_html=True)
#model config
MODEL_ID = "OrdinaryAI/visum-qlora-5epochs"
# =============================================================================
# LOAD MODEL
#
# @st.cache_resource:
# Streamlit chỉ load model 1 lần duy nhất
# Những lần sau dùng cache -> app nhanh hơn rất nhiều
# =============================================================================
@st.cache_resource
def load_model(model_id):
# Đọc config PEFT để biết model gốc là gì
peft_config = PeftConfig.from_pretrained(model_id)
# Load model gốc (vinai/bartpho-syllable)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
peft_config.base_model_name_or_path
)
# Gắn trọng số QLoRA vào
model = PeftModel.from_pretrained(base_model, model_id)
# Merge vào model gốc → inference nhanh hơn
model = model.merge_and_unload()
# Load tokenizer từ model gốc
tokenizer = AutoTokenizer.from_pretrained(
peft_config.base_model_name_or_path
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
return tokenizer, model
# =============================================================================
# HÀM TÓM TẮT
#
# Pipeline:
# text
# -> tokenize
# -> model.generate()
# -> decode
# =============================================================================
def summarize_text(
text,
tokenizer,
model,
max_length=150,
min_length=50,
num_beams=4
):
# Lấy device hiện tại của model
device = next(model.parameters()).device
# Bắt đầu tính thời gian xử lý
start_time = time.time()
# =========================================================
# TOKENIZE
# Chuyển text -> tensor số
# =========================================================
inputs = tokenizer(
text,
return_tensors="pt",
max_length=1024,
truncation=True,
padding=True
).to(device)
# =========================================================
# GENERATE SUMMARY
# =========================================================
with torch.no_grad():
output_ids = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
early_stopping=True,
# Tránh lặp cụm từ
no_repeat_ngram_size=3
)
# =========================================================
# DECODE
# Token IDs -> text
# =========================================================
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Thêm dòng này để fix dính từ
summary = fix_bartpho_output(summary)
elapsed_time = round(time.time() - start_time, 2)
return {
"summary": summary,
"time": elapsed_time
}
# =============================================================================
# SIDEBAR
# =============================================================================
with st.sidebar:
st.markdown("# 🇻🇳 ViSum")
st.caption("Vietnamese News Summarization")
st.markdown("---")
st.subheader("⚙️ Cài đặt")
# Slider độ dài tối đa
max_length = st.slider(
"Độ dài tối đa",
min_value=50,
max_value=500,
value=150,
step=10
)
# Slider độ dài tối thiểu
min_length = st.slider(
"Độ dài tối thiểu",
min_value=10,
max_value=200,
value=50,
step=10
)
# Beam search
num_beams = st.slider(
"Beam Search",
min_value=1,
max_value=8,
value=4,
step=1
)
st.markdown("---")
st.caption(f"Model: {MODEL_ID}")
st.caption("Ordinary-AI-Engineer")
# =============================================================================
# MAIN UI
# =============================================================================
st.title("ViSum - Hệ thống Tóm tắt Báo chí Tiếng Việt")
st.markdown("""
Dán bài báo hoặc đoạn văn tiếng Việt vào ô bên dưới,
sau đó nhấn **Tóm tắt** để AI tạo bản tóm tắt ngắn gọn.
""")
# =============================================================================
# INPUT TEXT AREA
# =============================================================================
input_text = st.text_area(
label="Văn bản gốc",
placeholder="Nhập nội dung tại đây...",
height=320
)
# =============================================================================
# BUTTON
# =============================================================================
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
summarize_button = st.button(
"Tóm tắt",
type="primary",
use_container_width=True
)
# =============================================================================
# XỬ LÝ KHI USER NHẤN NÚT
# =============================================================================
if summarize_button:
# Xóa khoảng trắng thừa
clean_text = input_text.strip()
# =========================================================
# VALIDATION
# =========================================================
if not clean_text:
st.error("Vui lòng nhập văn bản!")
elif len(clean_text) < 100:
st.warning(
"Văn bản quá ngắn! "
"Kết quả tóm tắt có thể không chính xác."
)
else:
#load model
with st.spinner("Đang load model..."):
tokenizer, model = load_model(MODEL_ID)
# =====================================================
# SUMMARIZE
# =====================================================
with st.spinner("Đang tóm tắt văn bản, xin hãy chờ trong giấy lát."):
result = summarize_text(
text=clean_text,
tokenizer=tokenizer,
model=model,
max_length=max_length,
min_length=min_length,
num_beams=num_beams
)
summary = result["summary"]
elapsed = result["time"]
# =====================================================
# OUTPUT
# =====================================================
st.success("Tóm tắt hoàn thành!")
st.text_area(
label="Kết quả tóm tắt: ",
value=summary,
height=220
)
# =====================================================
# METRICS
# =====================================================
original_words = len(clean_text.split())
summary_words = len(summary.split())
reduction_percent = round(
(1 - summary_words / original_words) * 100,
1
)
m1, m2, m3, m4 = st.columns(4)
m1.metric(
"Thời gian",
f"{elapsed}s"
)
m2.metric(
"Từ gốc",
original_words
)
m3.metric(
"Từ tóm tắt",
summary_words
)
m4.metric(
"Rút gọn",
f"{reduction_percent}%"
)
# =============================================================================
# FOOTER
# =============================================================================
st.markdown("---")
st.caption(
"ViSum • Vietnamese News Summarization System • "
"Powered by Hugging Face Transformers"
)