Spaces:
Sleeping
Sleeping
| #thư viện | |
| import time | |
| import torch | |
| import streamlit as st | |
| import re | |
| def fix_bartpho_output(text: str) -> str: | |
| """ | |
| BARTpho syllable hay bị dính từ kiểu: | |
| 'Coquan' → 'Cơ quan', 'đốitượng' → 'đối tượng' | |
| Hàm này thêm dấu cách trước chữ hoa giữa câu | |
| và fix một số pattern hay gặp. | |
| """ | |
| # Thêm space trước chữ hoa nằm giữa từ thường | |
| # Ví dụ: "CơQuan" → "Cơ Quan" | |
| text = re.sub(r'([a-zđàáâãèéêìíòóôõùúýăắặấầẩẫậắằẳẵặ])' | |
| r'([A-ZĐÀÁÂÃÈÉÊÌÍÒÓÔÕÙÚÝĂẮẶẤẦẨẪẬẮẰẲẴẶ])', | |
| r'\1 \2', text) | |
| # Fix dấu câu dính vào chữ: "vong.Cơ" → "vong. Cơ" | |
| text = re.sub(r'([.!?,;:])([^\s])', r'\1 \2', text) | |
| # Xóa khoảng trắng thừa | |
| text = re.sub(r' +', ' ', text).strip() | |
| return text | |
| # HuggingFace Transformers | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from peft import PeftModel, PeftConfig | |
| #page config | |
| st.set_page_config( | |
| page_title="ViSum - Vietnamese News Summarization", | |
| page_icon="🇻🇳", | |
| layout="wide" | |
| ) | |
| #custom css | |
| st.markdown(""" | |
| <style> | |
| /* Font toàn app */ | |
| html, body, [class*="css"] { | |
| font-family: 'Arial', sans-serif; | |
| } | |
| /* Nút chính */ | |
| .stButton > button[kind="primary"] { | |
| background-color: #1a73e8; | |
| color: white; | |
| border: none; | |
| border-radius: 10px; | |
| padding: 0.6rem 1.5rem; | |
| font-size: 16px; | |
| font-weight: 600; | |
| } | |
| .stButton > button[kind="primary"]:hover { | |
| background-color: #1557b0; | |
| } | |
| /* Text area */ | |
| .stTextArea textarea { | |
| border-radius: 10px; | |
| border: 1px solid #d0d0d0; | |
| line-height: 1.6; | |
| font-size: 15px; | |
| } | |
| /* Metric cards */ | |
| [data-testid="metric-container"] { | |
| background-color: #f8f9fa; | |
| border: 1px solid #e0e0e0; | |
| padding: 15px; | |
| border-radius: 12px; | |
| } | |
| /* Responsive cho mobile */ | |
| @media (max-width: 768px) { | |
| h1 { | |
| font-size: 1.8rem; | |
| } | |
| .stTextArea textarea { | |
| font-size: 14px; | |
| } | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| #model config | |
| MODEL_ID = "OrdinaryAI/visum-qlora-5epochs" | |
| # ============================================================================= | |
| # LOAD MODEL | |
| # | |
| # @st.cache_resource: | |
| # Streamlit chỉ load model 1 lần duy nhất | |
| # Những lần sau dùng cache -> app nhanh hơn rất nhiều | |
| # ============================================================================= | |
| def load_model(model_id): | |
| # Đọc config PEFT để biết model gốc là gì | |
| peft_config = PeftConfig.from_pretrained(model_id) | |
| # Load model gốc (vinai/bartpho-syllable) | |
| base_model = AutoModelForSeq2SeqLM.from_pretrained( | |
| peft_config.base_model_name_or_path | |
| ) | |
| # Gắn trọng số QLoRA vào | |
| model = PeftModel.from_pretrained(base_model, model_id) | |
| # Merge vào model gốc → inference nhanh hơn | |
| model = model.merge_and_unload() | |
| # Load tokenizer từ model gốc | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| peft_config.base_model_name_or_path | |
| ) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = model.to(device) | |
| model.eval() | |
| return tokenizer, model | |
| # ============================================================================= | |
| # HÀM TÓM TẮT | |
| # | |
| # Pipeline: | |
| # text | |
| # -> tokenize | |
| # -> model.generate() | |
| # -> decode | |
| # ============================================================================= | |
| def summarize_text( | |
| text, | |
| tokenizer, | |
| model, | |
| max_length=150, | |
| min_length=50, | |
| num_beams=4 | |
| ): | |
| # Lấy device hiện tại của model | |
| device = next(model.parameters()).device | |
| # Bắt đầu tính thời gian xử lý | |
| start_time = time.time() | |
| # ========================================================= | |
| # TOKENIZE | |
| # Chuyển text -> tensor số | |
| # ========================================================= | |
| inputs = tokenizer( | |
| text, | |
| return_tensors="pt", | |
| max_length=1024, | |
| truncation=True, | |
| padding=True | |
| ).to(device) | |
| # ========================================================= | |
| # GENERATE SUMMARY | |
| # ========================================================= | |
| with torch.no_grad(): | |
| output_ids = model.generate( | |
| input_ids=inputs["input_ids"], | |
| attention_mask=inputs["attention_mask"], | |
| max_length=max_length, | |
| min_length=min_length, | |
| num_beams=num_beams, | |
| early_stopping=True, | |
| # Tránh lặp cụm từ | |
| no_repeat_ngram_size=3 | |
| ) | |
| # ========================================================= | |
| # DECODE | |
| # Token IDs -> text | |
| # ========================================================= | |
| summary = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
| # Thêm dòng này để fix dính từ | |
| summary = fix_bartpho_output(summary) | |
| elapsed_time = round(time.time() - start_time, 2) | |
| return { | |
| "summary": summary, | |
| "time": elapsed_time | |
| } | |
| # ============================================================================= | |
| # SIDEBAR | |
| # ============================================================================= | |
| with st.sidebar: | |
| st.markdown("# 🇻🇳 ViSum") | |
| st.caption("Vietnamese News Summarization") | |
| st.markdown("---") | |
| st.subheader("⚙️ Cài đặt") | |
| # Slider độ dài tối đa | |
| max_length = st.slider( | |
| "Độ dài tối đa", | |
| min_value=50, | |
| max_value=500, | |
| value=150, | |
| step=10 | |
| ) | |
| # Slider độ dài tối thiểu | |
| min_length = st.slider( | |
| "Độ dài tối thiểu", | |
| min_value=10, | |
| max_value=200, | |
| value=50, | |
| step=10 | |
| ) | |
| # Beam search | |
| num_beams = st.slider( | |
| "Beam Search", | |
| min_value=1, | |
| max_value=8, | |
| value=4, | |
| step=1 | |
| ) | |
| st.markdown("---") | |
| st.caption(f"Model: {MODEL_ID}") | |
| st.caption("Ordinary-AI-Engineer") | |
| # ============================================================================= | |
| # MAIN UI | |
| # ============================================================================= | |
| st.title("ViSum - Hệ thống Tóm tắt Báo chí Tiếng Việt") | |
| st.markdown(""" | |
| Dán bài báo hoặc đoạn văn tiếng Việt vào ô bên dưới, | |
| sau đó nhấn **Tóm tắt** để AI tạo bản tóm tắt ngắn gọn. | |
| """) | |
| # ============================================================================= | |
| # INPUT TEXT AREA | |
| # ============================================================================= | |
| input_text = st.text_area( | |
| label="Văn bản gốc", | |
| placeholder="Nhập nội dung tại đây...", | |
| height=320 | |
| ) | |
| # ============================================================================= | |
| # BUTTON | |
| # ============================================================================= | |
| col1, col2, col3 = st.columns([1, 2, 1]) | |
| with col2: | |
| summarize_button = st.button( | |
| "Tóm tắt", | |
| type="primary", | |
| use_container_width=True | |
| ) | |
| # ============================================================================= | |
| # XỬ LÝ KHI USER NHẤN NÚT | |
| # ============================================================================= | |
| if summarize_button: | |
| # Xóa khoảng trắng thừa | |
| clean_text = input_text.strip() | |
| # ========================================================= | |
| # VALIDATION | |
| # ========================================================= | |
| if not clean_text: | |
| st.error("Vui lòng nhập văn bản!") | |
| elif len(clean_text) < 100: | |
| st.warning( | |
| "Văn bản quá ngắn! " | |
| "Kết quả tóm tắt có thể không chính xác." | |
| ) | |
| else: | |
| #load model | |
| with st.spinner("Đang load model..."): | |
| tokenizer, model = load_model(MODEL_ID) | |
| # ===================================================== | |
| # SUMMARIZE | |
| # ===================================================== | |
| with st.spinner("Đang tóm tắt văn bản, xin hãy chờ trong giấy lát."): | |
| result = summarize_text( | |
| text=clean_text, | |
| tokenizer=tokenizer, | |
| model=model, | |
| max_length=max_length, | |
| min_length=min_length, | |
| num_beams=num_beams | |
| ) | |
| summary = result["summary"] | |
| elapsed = result["time"] | |
| # ===================================================== | |
| # OUTPUT | |
| # ===================================================== | |
| st.success("Tóm tắt hoàn thành!") | |
| st.text_area( | |
| label="Kết quả tóm tắt: ", | |
| value=summary, | |
| height=220 | |
| ) | |
| # ===================================================== | |
| # METRICS | |
| # ===================================================== | |
| original_words = len(clean_text.split()) | |
| summary_words = len(summary.split()) | |
| reduction_percent = round( | |
| (1 - summary_words / original_words) * 100, | |
| 1 | |
| ) | |
| m1, m2, m3, m4 = st.columns(4) | |
| m1.metric( | |
| "Thời gian", | |
| f"{elapsed}s" | |
| ) | |
| m2.metric( | |
| "Từ gốc", | |
| original_words | |
| ) | |
| m3.metric( | |
| "Từ tóm tắt", | |
| summary_words | |
| ) | |
| m4.metric( | |
| "Rút gọn", | |
| f"{reduction_percent}%" | |
| ) | |
| # ============================================================================= | |
| # FOOTER | |
| # ============================================================================= | |
| st.markdown("---") | |
| st.caption( | |
| "ViSum • Vietnamese News Summarization System • " | |
| "Powered by Hugging Face Transformers" | |
| ) | |