Spaces:
Sleeping
Sleeping
import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor

import torch
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Hugging Face Hub id of the summarization checkpoint; used for both tokenizer and model.
_MODEL_ID = "VietAI/vit5-large-vietnews-summarization"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and seq2seq model once, at import time; `.to()` returns the same module.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID).to(device)

# Warm-up: one short generation at startup so the first real request does not pay
# the one-off initialisation cost.
dummy_input = tokenizer(
    "Tin nhanh: Đây là văn bản mẫu để warmup mô hình.",
    return_tensors="pt",
).to(device)
with torch.no_grad():
    _ = model.generate(**dummy_input, max_length=32)
class SummarizeRequest(BaseModel):
    """JSON request body for the summarization endpoint.

    Attributes:
        text: Raw text to summarize (required; declared without a default).
    """

    text: str
# Attach the handler to `app`; without a route decorator it is never served.
# NOTE(review): "/" is assumed as the readiness path — confirm against the Space's
# health check and any clients.
@app.get("/")
async def root():
    """Readiness probe.

    The tokenizer and model are loaded and warmed up at import time, before the app
    can serve requests, so any response from this route means the model is ready.

    Returns:
        A dict with a fixed ``"message"`` string.
    """
    return {"message": "Model is ready."}
# Lower-cased openings that make the input already look like a news article:
# "theo" (according to), "trong khi" (while), "bộ" (ministry), "ngày" (on [date]),
# "việt nam" (Vietnam), "công an" (police).
_NEWS_OPENINGS = ("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")

# Generation is blocking, so it must not run on the event loop. A single worker
# frees the loop for other requests while still running at most one generation at
# a time, keeping GPU/CPU load at one inference per moment.
_GENERATION_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="summarize")


def _build_model_input(text: str) -> str:
    """Return the prompt string passed to the tokenizer for the client's *text*.

    Strips surrounding whitespace, prepends "Vietnews: " when the text opens like a
    news article (one of ``_NEWS_OPENINGS``, case-insensitive) and "Tin nhanh: "
    ("quick news") otherwise, then appends an explicit " </s>" end-of-sequence marker.
    """
    stripped = text.strip()
    # Preprocessing: text that does not look like news gets the "Tin nhanh:" prefix.
    if stripped.lower().startswith(_NEWS_OPENINGS):
        # NOTE(review): the model card's example prompt uses lowercase "vietnews: " —
        # confirm the capitalised form here is intentional.
        prefix = "Vietnews: "
    else:
        prefix = "Tin nhanh: "
    return prefix + stripped + " </s>"


def _generate_summary(input_text: str) -> str:
    """Tokenize *input_text*, generate greedily and decode the first output sequence.

    Blocking: runs the model. Call it off the event loop (see ``summarize``).
    """
    encoding = tokenizer(input_text, return_tensors="pt")
    # Stable decoding configuration: greedy search (num_beams=1), no early_stopping.
    outputs = model.generate(
        input_ids=encoding["input_ids"].to(device),
        attention_mask=encoding["attention_mask"].to(device),
        max_length=256,
        num_beams=1,  # greedy decoding
        no_repeat_ngram_size=2,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)


# NOTE(review): "/summarize" is assumed as the route path — confirm against clients.
@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    """Summarize ``body.text`` with the loaded seq2seq model.

    Model work runs on a dedicated single-thread executor so the event loop stays
    responsive while a summary is being generated.

    Args:
        req: The incoming request; used only to log the client address.
        body: Parsed request body holding the text to summarize.

    Returns:
        ``{"summary": <generated summary string>}``.
    """
    # perf_counter is monotonic; time.time() can jump with wall-clock adjustments.
    started = time.perf_counter()
    # Starlette's Request.client is optional (e.g. some ASGI servers / test clients).
    client_ip = req.client.host if req.client is not None else "unknown"
    logger.info("[%s] 🔵 Received request from %s", time.strftime("%Y-%m-%d %H:%M:%S"), client_ip)

    input_text = _build_model_input(body.text)

    import asyncio  # function-scope so this edit stands alone; cached by sys.modules

    loop = asyncio.get_running_loop()
    summary = await loop.run_in_executor(_GENERATION_EXECUTOR, _generate_summary, input_text)

    elapsed = time.perf_counter() - started
    logger.info("[%s] ✅ Response sent — total time: %.2fs", time.strftime("%Y-%m-%d %H:%M:%S"), elapsed)
    return {"summary": summary}