Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Text Summarization App for Hugging Face CPU Space | |
| Model: vinai/bartpho-syllable-base | |
| """ | |
| import re | |
| from typing import Optional | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
| import fitz # PyMuPDF | |
| import json | |
| import gc | |
| import asyncio | |
| # ============================================================ | |
| # Initialize FastAPI App | |
| # ============================================================ | |
| app = FastAPI( | |
| title="Vietnamese Text Summarizer", | |
| description="Summarize Vietnamese text using BARTpho model", | |
| version="1.0.0" | |
| ) | |
| # CORS middleware - Allow All Origins for GitHub Pages | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # ============================================================ | |
| # Load Model | |
| # ============================================================ | |
| print("Loading BARTpho model...") | |
| MODEL_NAME = "vinai/bartpho-syllable-base" | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME) | |
| model.eval() # Set to evaluation mode | |
| print("Model loaded successfully!") | |
| # ============================================================ | |
| # Request Models | |
| # ============================================================ | |
| class SummarizeRequest(BaseModel): | |
| text: str | |
| length_level: int = 1 # 0: Ngắn (2-3 ý), 1: Trung bình (4-5 ý), 2: Chi tiết (6+ ý) | |
| # ============================================================ | |
| # Helper Functions | |
| # ============================================================ | |
| def chunk_text_by_words(text: str, max_words: int = 800) -> list[str]: | |
| """ | |
| Chia văn bản thành các đoạn tối đa max_words từ. | |
| Giữ nguyên câu hoàn chỉnh khi có thể. | |
| """ | |
| # Clean text | |
| text = re.sub(r'\s+', ' ', text).strip() | |
| # Split into sentences | |
| sentences = re.split(r'(?<=[.!?])\s+', text) | |
| chunks = [] | |
| current_chunk = [] | |
| current_word_count = 0 | |
| for sentence in sentences: | |
| sentence_words = sentence.split() | |
| sentence_word_count = len(sentence_words) | |
| # Nếu câu đơn lẻ dài hơn max_words, chia nhỏ câu đó | |
| if sentence_word_count > max_words: | |
| if current_chunk: | |
| chunks.append(' '.join(current_chunk)) | |
| current_chunk = [] | |
| current_word_count = 0 | |
| for i in range(0, sentence_word_count, max_words): | |
| chunk_words = sentence_words[i:i + max_words] | |
| chunks.append(' '.join(chunk_words)) | |
| # Nếu thêm câu này vượt quá giới hạn | |
| elif current_word_count + sentence_word_count > max_words: | |
| if current_chunk: | |
| chunks.append(' '.join(current_chunk)) | |
| current_chunk = [sentence] | |
| current_word_count = sentence_word_count | |
| else: | |
| current_chunk.append(sentence) | |
| current_word_count += sentence_word_count | |
| # Lưu chunk cuối cùng | |
| if current_chunk: | |
| chunks.append(' '.join(current_chunk)) | |
| return chunks | |
| def fix_truncated_text(text: str) -> str: | |
| """ | |
| Nếu kết quả không kết thúc bằng dấu câu, | |
| tự động cắt đến dấu chấm gần nhất. | |
| """ | |
| text = text.strip() | |
| if not text: | |
| return text | |
| # Nếu đã kết thúc bằng dấu câu, trả về nguyên | |
| if text[-1] in '.!?': | |
| return text | |
| # Tìm dấu câu gần nhất | |
| last_period = text.rfind('.') | |
| last_exclaim = text.rfind('!') | |
| last_question = text.rfind('?') | |
| last_sentence_end = max(last_period, last_exclaim, last_question) | |
| if last_sentence_end > 0: | |
| # Cắt đến dấu câu gần nhất | |
| return text[:last_sentence_end + 1] | |
| # Nếu không có dấu câu nào, thêm dấu chấm | |
| return text + '.' | |
| def format_as_bullet_points(summaries: list[str], max_points: int = None) -> str: | |
| """ | |
| Chuyển đổi các đoạn tóm tắt thành Bullet Points. | |
| Mỗi ý một dòng, bắt đầu bằng '•'. | |
| max_points: Số lượng bullet points tối đa (None = không giới hạn) | |
| """ | |
| bullet_points = [] | |
| for summary in summaries: | |
| # Chia thành các câu | |
| sentences = re.split(r'(?<=[.!?])\s+', summary) | |
| for sentence in sentences: | |
| sentence = sentence.strip() | |
| if sentence and len(sentence) > 15: # Bỏ qua câu quá ngắn | |
| # Đảm bảo câu kết thúc đúng | |
| sentence = fix_truncated_text(sentence) | |
| bullet_points.append(f"• {sentence}") | |
| # Giới hạn số lượng bullet points nếu có | |
| if max_points is not None and len(bullet_points) > max_points: | |
| bullet_points = bullet_points[:max_points] | |
| return '\n'.join(bullet_points) | |
| def generate_summary(text: str) -> str: | |
| """ | |
| Sinh tóm tắt với torch.no_grad() để tiết kiệm RAM. | |
| Tham số: max_length=350, min_length=100, num_beams=4, repetition_penalty=2.5 | |
| """ | |
| try: | |
| # Tokenize input | |
| inputs = tokenizer( | |
| text, | |
| max_length=1024, | |
| truncation=True, | |
| return_tensors="pt" | |
| ) | |
| # Generate với torch.no_grad() để tiết kiệm RAM | |
| with torch.no_grad(): | |
| summary_ids = model.generate( | |
| inputs["input_ids"], | |
| attention_mask=inputs["attention_mask"], | |
| max_length=350, | |
| min_length=100, | |
| num_beams=4, | |
| repetition_penalty=2.5, | |
| no_repeat_ngram_size=3, | |
| early_stopping=True | |
| ) | |
| # Decode output | |
| summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) | |
| # Fix truncated text | |
| summary = fix_truncated_text(summary) | |
| return summary | |
| except Exception as e: | |
| print(f"Error generating summary: {e}") | |
| return "" | |
| def summarize_long_text(text: str) -> list[str]: | |
| """ | |
| Nếu văn bản > 800 từ, chia nhỏ và tóm tắt từng phần. | |
| """ | |
| word_count = len(text.split()) | |
| # Nếu văn bản ngắn, tóm tắt trực tiếp | |
| if word_count <= 800: | |
| summary = generate_summary(text) | |
| return [summary] if summary else [] | |
| # Chia nhỏ văn bản dài | |
| chunks = chunk_text_by_words(text, max_words=800) | |
| summaries = [] | |
| for i, chunk in enumerate(chunks): | |
| print(f"Processing chunk {i + 1}/{len(chunks)}...") | |
| summary = generate_summary(chunk) | |
| if summary: | |
| summaries.append(summary) | |
| return summaries | |
| def extract_text_from_pdf_bytes(pdf_bytes: bytes) -> str: | |
| """ | |
| Đọc PDF từ byte stream sử dụng PyMuPDF. | |
| KHÔNG lưu file ra đĩa. | |
| """ | |
| try: | |
| # Mở PDF từ byte stream | |
| doc = fitz.open(stream=pdf_bytes, filetype="pdf") | |
| text_parts = [] | |
| for page_num in range(len(doc)): | |
| page = doc[page_num] | |
| text = page.get_text("text") | |
| if text: | |
| text_parts.append(text) | |
| doc.close() | |
| full_text = '\n'.join(text_parts) | |
| if not full_text.strip(): | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Không thể trích xuất văn bản từ PDF. File có thể là ảnh scan." | |
| ) | |
| return full_text | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Lỗi khi đọc PDF: {str(e)}" | |
| ) | |
| # ============================================================ | |
| # API Endpoints | |
| # ============================================================ | |
| async def root(): | |
| """Health check endpoint.""" | |
| return { | |
| "status": "running", | |
| "model": "vinai/bartpho-syllable-base", | |
| "endpoints": ["/summarize", "/upload-pdf"] | |
| } | |
| async def health_check(): | |
| """ | |
| Health check cho Frontend kiểm tra Space đã khởi động chưa. | |
| Không chạy qua model AI - phản hồi tức thì. | |
| """ | |
| return {"status": "online"} | |
| async def summarize_text(request: SummarizeRequest): | |
| """ | |
| Tóm tắt văn bản tiếng Việt. | |
| Trả về danh sách Bullet Points. | |
| length_level: 0 = Ngắn (2-3 ý), 1 = Trung bình (4-5 ý), 2 = Chi tiết (6+ ý) | |
| """ | |
| text = request.text | |
| length_level = request.length_level | |
| # Map length_level to max_points | |
| max_points_map = { | |
| 0: 3, # Ngắn: 2-3 ý | |
| 1: 5, # Trung bình: 4-5 ý | |
| 2: None # Chi tiết: không giới hạn | |
| } | |
| max_points = max_points_map.get(length_level, 5) | |
| # Validate text | |
| if not text or len(text.strip()) < 50: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Văn bản quá ngắn để tóm tắt (cần ít nhất 50 ký tự)." | |
| ) | |
| # Return StreamingResponse | |
| return StreamingResponse( | |
| stream_summary_generator(text, max_points), | |
| media_type="application/x-ndjson" | |
| ) | |
| async def stream_summary_generator(text: str, max_points: int = None): | |
| """ | |
| Generator function to stream summary chunks. | |
| Implements: Recursive chunking, Context-aware summarization, Memory optimization. | |
| YIELDS JSON: {"text": "• Point 1\n", "done": False} + "\n" | |
| """ | |
| chunks = chunk_text_by_words(text, max_words=800) | |
| total_chunks = len(chunks) | |
| # Context cho chunk tiếp theo (summary của chunk trước) | |
| context_summary = "" | |
| bullet_count = 0 | |
| for i, chunk in enumerate(chunks): | |
| # 1. Prepare input: Context + Current Chunk | |
| # Nếu có context, nối vào đầu chunk (có phân cách) | |
| if context_summary: | |
| # Giới hạn context để tránh quá dài (lấy 200 ký tự cuối của summary trước) | |
| short_context = context_summary[-200:] if len(context_summary) > 200 else context_summary | |
| input_text = f"Tóm tắt tiếp theo ngữ cảnh: {short_context}\nNội dung: {chunk}" | |
| else: | |
| input_text = chunk | |
| # 2. Generate Summary | |
| # Chạy trong threadpool để không chặn event loop của FastAPI | |
| try: | |
| summary_part = await asyncio.to_thread(generate_summary, input_text) | |
| except Exception as e: | |
| error_json = json.dumps({"error": str(e), "done": True}) | |
| yield error_json + "\n" | |
| return | |
| # 3. Format as bullets | |
| # Chỉ lấy max_points còn lại nếu có giới hạn | |
| points_limit = None | |
| if max_points is not None: | |
| points_limit = max_points - bullet_count | |
| if points_limit <= 0: | |
| break # Đã đủ số ý | |
| bullets_text = format_as_bullet_points([summary_part], max_points=points_limit) | |
| if bullets_text: | |
| # Update context for next iteration | |
| context_summary = summary_part.replace('\n', ' ') | |
| # Count bullets | |
| new_points = bullets_text.count('•') | |
| bullet_count += new_points | |
| # Yield Result | |
| result_json = json.dumps({ | |
| "text": bullets_text + "\n", | |
| "done": False, | |
| "progress": int((i + 1) / total_chunks * 100) | |
| }) | |
| yield result_json + "\n" | |
| # 4. Memory Optimization | |
| try: | |
| del input_text | |
| del summary_part | |
| except UnboundLocalError: | |
| pass | |
| gc.collect() # Force garbage collection | |
| # Nhường CPU cho request khác 1 chút | |
| await asyncio.sleep(0.1) | |
| # Final message | |
| yield json.dumps({"text": "", "done": True, "progress": 100}) + "\n" | |
| async def upload_pdf(file: UploadFile = File(...), length_level: int = 1): | |
| """ | |
| Upload và tóm tắt file PDF. | |
| Đọc qua byte stream, KHÔNG lưu file ra đĩa. | |
| Trả về StreamingResponse (NDJSON). | |
| length_level: 0 = Ngắn (2-3 ý), 1 = Trung bình (4-5 ý), 2 = Chi tiết (6+ ý) | |
| """ | |
| # Map length_level to max_points | |
| max_points_map = { | |
| 0: 3, # Ngắn: 2-3 ý | |
| 1: 5, # Trung bình: 4-5 ý | |
| 2: None # Chi tiết: không giới hạn | |
| } | |
| max_points = max_points_map.get(length_level, 5) | |
| # Validate file type | |
| if not file.filename.lower().endswith('.pdf'): | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Chỉ hỗ trợ file PDF." | |
| ) | |
| # Đọc file qua contents = await file.read() | |
| contents = await file.read() | |
| if len(contents) == 0: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="File rỗng." | |
| ) | |
| # Limit file size (10MB max) | |
| max_size = 10 * 1024 * 1024 # 10MB | |
| if len(contents) > max_size: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="File quá lớn. Giới hạn 10MB." | |
| ) | |
| # Extract text from PDF bytes | |
| text = extract_text_from_pdf_bytes(contents) | |
| # Validate extracted text | |
| if len(text.strip()) < 50: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Văn bản trích xuất từ PDF quá ngắn." | |
| ) | |
| # Return StreamingResponse | |
| return StreamingResponse( | |
| stream_summary_generator(text, max_points), | |
| media_type="application/x-ndjson" | |
| ) | |
| # ============================================================ | |
| # Run with Uvicorn (for local development) | |
| # ============================================================ | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |