# Libraries
import re
import time
import unicodedata

import streamlit as st
import torch
# Hugging Face Transformers + PEFT (adapter loading)
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Punctuation glued to a following *letter*, e.g. "vong.Cơ".
# Only letters trigger the fix so that decimals ("1.5"), thousands separators
# ("1.000"), clock times ("10:30") and ellipses ("...") are left untouched.
_GLUED_PUNCT_RE = re.compile(r"([.!?,;:])(?=[^\W\d_])")
# Runs of two or more spaces (a single space is rewritten to itself).
_SPACE_RUN_RE = re.compile(r" +")


def _split_case_boundaries(text: str) -> str:
    """Insert a space wherever a lowercase letter is directly followed by an uppercase one."""
    pieces = []
    prev_is_lower = False
    for ch in text:
        if unicodedata.combining(ch):
            # Combining mark of a decomposed letter (e.g. "o" + horn): it belongs
            # to the previous base letter, so the case state must not be reset.
            pieces.append(ch)
            continue
        if prev_is_lower and ch.isupper():
            pieces.append(" ")
        pieces.append(ch)
        prev_is_lower = ch.islower()
    return "".join(pieces)


def fix_bartpho_output(text: str) -> str:
    """Repair common spacing glitches in decoded BARTpho (syllable) output.

    Three passes are applied, in order:

    1. A space is inserted at every lowercase -> uppercase boundary inside a
       token ("CơQuan" -> "Cơ Quan"). Case is decided with ``str.islower`` /
       ``str.isupper`` so every Vietnamese letter is covered, not just a
       hand-written subset.
    2. A space is inserted after ``. ! ? , ; :`` when a letter follows
       immediately ("vong.Cơ" -> "vong. Cơ"). Digits and other punctuation
       after the mark are left alone.
    3. Runs of spaces are collapsed and the result is stripped.

    Words glued together without any case change (e.g. "đốitượng") cannot be
    detected by these rules and are returned unchanged.

    Args:
        text: Decoded model output.

    Returns:
        The text with the spacing fixes applied.
    """
    text = _split_case_boundaries(text)
    text = _GLUED_PUNCT_RE.sub(r"\1 ", text)
    return _SPACE_RUN_RE.sub(" ", text).strip()


# Page config
st.set_page_config(
    page_title="ViSum - Vietnamese News Summarization",
    page_icon="🇻🇳",
    layout="wide"
)

# Custom CSS
# NOTE(review): the CSS payload is empty here; confirm whether a <style> block
# is supposed to be injected.
st.markdown(""" """, unsafe_allow_html=True)

# Model config
MODEL_ID = "OrdinaryAI/visum-qlora-5epochs"


# =============================================================================
# LOAD MODEL
#
# @st.cache_resource:
#   Streamlit loads the model only once; later reruns reuse the cached
#   objects, which makes the app much faster.
# =============================================================================
@st.cache_resource
def load_model(model_id):
    """Load the adapter-tuned seq2seq model and its tokenizer.

    The PEFT config stored under ``model_id`` names the base checkpoint
    (per the original author's note: vinai/bartpho-syllable). The base model
    is loaded, the adapter weights are attached and then merged into it, and
    the merged model is moved to CUDA when available (CPU otherwise) and put
    in eval mode.

    Args:
        model_id: Hub id or local path of the PEFT adapter.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Read the PEFT config to find out which base model the adapter targets.
    peft_config = PeftConfig.from_pretrained(model_id)
    base_name = peft_config.base_model_name_or_path

    # Load the base model, then attach the adapter weights.
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_name)
    model = PeftModel.from_pretrained(base_model, model_id)

    # Merge the adapter into the base weights -> faster inference.
    model = model.merge_and_unload()

    # The tokenizer comes from the base model.
    tokenizer = AutoTokenizer.from_pretrained(base_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    return tokenizer, model
# =============================================================================
# SUMMARIZATION FUNCTION
#
# Pipeline:
#   text
#   -> tokenize
#   -> model.generate()
#   -> decode
# =============================================================================
def summarize_text(
    text,
    tokenizer,
    model,
    max_length=150,
    min_length=50,
    num_beams=4
):
    """Summarize ``text`` and report how long it took.

    Args:
        text: Source document to summarize.
        tokenizer: Tokenizer matching ``model``.
        model: Seq2seq model exposing ``generate``.
        max_length: Upper bound passed to ``generate``.
        min_length: Lower bound passed to ``generate``. Clamped to
            ``max_length`` when larger, because the two values come from
            independent sliders and an inverted pair is infeasible.
        num_beams: Beam width passed to ``generate``.

    Returns:
        ``{"summary": <str>, "time": <seconds, rounded to 2 decimals>}``.
    """
    # Run on whatever device the model currently lives on.
    device = next(model.parameters()).device

    # Keep the length constraints feasible (min must not exceed max).
    min_length = min(min_length, max_length)

    # Monotonic clock: unaffected by system clock adjustments.
    start_time = time.perf_counter()

    # -------------------------------------------------------------------------
    # TOKENIZE: text -> tensors (input is truncated to 1024 tokens)
    # -------------------------------------------------------------------------
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
        padding=True
    ).to(device)

    # -------------------------------------------------------------------------
    # GENERATE SUMMARY
    # -------------------------------------------------------------------------
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            # early_stopping only applies to beam search; skip it for a
            # single beam so transformers does not warn about an unused flag.
            early_stopping=num_beams > 1,
            # Avoid repeating phrases
            no_repeat_ngram_size=3
        )

    # -------------------------------------------------------------------------
    # DECODE: token ids -> text
    # -------------------------------------------------------------------------
    summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Repair words/punctuation that the model glued together.
    summary = fix_bartpho_output(summary)

    elapsed_time = round(time.perf_counter() - start_time, 2)

    return {
        "summary": summary,
        "time": elapsed_time
    }


# =============================================================================
# SIDEBAR
# =============================================================================
with st.sidebar:
    st.markdown("# 🇻🇳 ViSum")
    st.caption("Vietnamese News Summarization")
    st.markdown("---")

    st.subheader("⚙️ Cài đặt")

    # Maximum summary length
    max_length = st.slider(
        "Độ dài tối đa",
        min_value=50,
        max_value=500,
        value=150,
        step=10
    )

    # Minimum summary length
    min_length = st.slider(
        "Độ dài tối thiểu",
        min_value=10,
        max_value=200,
        value=50,
        step=10
    )

    # Beam search width
    num_beams = st.slider(
        "Beam Search",
        min_value=1,
        max_value=8,
        value=4,
        step=1
    )

    st.markdown("---")
    st.caption(f"Model: {MODEL_ID}")
    st.caption("Ordinary-AI-Engineer")

# =============================================================================
# MAIN UI
# =============================================================================
st.title("ViSum - Hệ thống Tóm tắt Báo chí Tiếng Việt")

st.markdown("""
Dán bài báo hoặc đoạn văn tiếng Việt vào ô bên dưới,
sau đó nhấn **Tóm tắt** để AI tạo bản tóm tắt ngắn gọn.
""")

# =============================================================================
# INPUT TEXT AREA
# =============================================================================
input_text = st.text_area(
    label="Văn bản gốc",
    placeholder="Nhập nội dung tại đây...",
    height=320
)

# =============================================================================
# BUTTON (centered in the middle column)
# =============================================================================
col1, col2, col3 = st.columns([1, 2, 1])

with col2:
    summarize_button = st.button(
        "Tóm tắt",
        type="primary",
        use_container_width=True
    )

# =============================================================================
# HANDLE THE BUTTON CLICK
# =============================================================================
if summarize_button:
    # Drop leading/trailing whitespace
    clean_text = input_text.strip()

    # -------------------------------------------------------------------------
    # VALIDATION
    # -------------------------------------------------------------------------
    if not clean_text:
        st.error("Vui lòng nhập văn bản!")
    elif len(clean_text) < 100:
        # NOTE(review): the message reads like a soft warning, but this branch
        # skips summarization entirely - confirm that is the intended behavior.
        st.warning(
            "Văn bản quá ngắn! "
            "Kết quả tóm tắt có thể không chính xác."
        )
    else:
        # Load the model (served from the Streamlit cache after the first call)
        with st.spinner("Đang load model..."):
            tokenizer, model = load_model(MODEL_ID)

        # ---------------------------------------------------------------------
        # SUMMARIZE
        # ---------------------------------------------------------------------
        with st.spinner("Đang tóm tắt văn bản, xin hãy chờ trong giây lát."):
            result = summarize_text(
                text=clean_text,
                tokenizer=tokenizer,
                model=model,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams
            )

        summary = result["summary"]
        elapsed = result["time"]

        # ---------------------------------------------------------------------
        # OUTPUT
        # ---------------------------------------------------------------------
        st.success("Tóm tắt hoàn thành!")

        st.text_area(
            label="Kết quả tóm tắt: ",
            value=summary,
            height=220
        )

        # ---------------------------------------------------------------------
        # METRICS
        # ---------------------------------------------------------------------
        # clean_text is non-empty here (>= 100 characters after strip), so
        # original_words is at least 1 and the division below is safe.
        original_words = len(clean_text.split())
        summary_words = len(summary.split())

        reduction_percent = round(
            (1 - summary_words / original_words) * 100,
            1
        )

        m1, m2, m3, m4 = st.columns(4)

        m1.metric("Thời gian", f"{elapsed}s")
        m2.metric("Từ gốc", original_words)
        m3.metric("Từ tóm tắt", summary_words)
        m4.metric("Rút gọn", f"{reduction_percent}%")

# =============================================================================
# FOOTER
# =============================================================================
st.markdown("---")

st.caption(
    "ViSum • Vietnamese News Summarization System • "
    "Powered by Hugging Face Transformers"
)