# PDF summarizer service (uploaded to Hugging Face, commit e8e26ec)
import os
import logging
import math
import time
import uuid
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import io
import fitz # PyMuPDF
import json
from transformers import pipeline
from typing import Iterator, Optional
import re
# Model name: overridable via env; defaults to a Vietnamese-optimized model.
MODEL_NAME = os.getenv("SUMMARIZER_MODEL_VI_VN", "VietAI/vit5-base-vietnamese")
# Primary/fallback pair chosen for CPU usage on the Hugging Face free tier.
PRIMARY_VI_MODEL = MODEL_NAME
FALLBACK_MODEL = "google/mt5-small"
# Chunk and safety configuration (CPU-friendly), configurable via environment.
CHUNK_WORDS = int(os.getenv("CHUNK_WORDS", "600"))  # smaller chunks to reduce per-chunk compute
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "20"))  # safety limit to avoid long processing times
logger = logging.getLogger("pdf_summarizer")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
app = FastAPI(title="PDF Summarizer with Streaming", version="0.1.0")
# CORS: allow all origins (wide open — acceptable for a demo, tighten for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Summarizer pipeline and the name of the model actually loaded; populated
# once by the startup hook and reused for all requests.
summarizer = None
current_model_name = None
@app.on_event("startup")
def load_model():
    """Load the summarization pipeline once at startup.

    Tries the primary Vietnamese model first; on any failure falls back to
    the smaller multilingual MT5 model. Sets the module-level ``summarizer``
    and ``current_model_name`` globals used by every request.
    """
    global summarizer, current_model_name
    current_model_name = PRIMARY_VI_MODEL
    try:
        logger.info(f"Loading Vietnamese model for CPU: {PRIMARY_VI_MODEL}")
        summarizer = pipeline("summarization", model=PRIMARY_VI_MODEL)
        logger.info("Vietnamese model loaded successfully.")
    except Exception as e:
        # Any load failure (download, memory, missing weights) triggers fallback.
        logger.warning(f"Failed to load Vietnamese model ({PRIMARY_VI_MODEL}) due to: {e}. Falling back to MT5-small.")
        current_model_name = FALLBACK_MODEL
        summarizer = pipeline("summarization", model=FALLBACK_MODEL)
        logger.info("Fallback model MT5-small loaded.")
def pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from an in-memory PDF without writing to disk.

    Args:
        pdf_bytes: raw PDF file contents.

    Returns:
        The text of all pages joined with newlines; pages yielding no text
        are skipped. Empty string when nothing is extractable.
    """
    texts = []
    # Context manager guarantees the document handle is closed even if
    # text extraction raises mid-way (the previous version leaked it).
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            text = page.get_text("text")
            if text:
                texts.append(text)
    return "\n".join(texts)
def finalize_sentence(text: str) -> str:
    """
    Return *text* trimmed so that it ends on sentence punctuation.

    The stripped text is returned unchanged when it already ends with '.',
    '!' or '?'. Otherwise it is cut back to the last such mark, or — when
    no mark exists at all — a period is appended.
    """
    trimmed = text.strip()
    # Empty input, or already properly terminated: nothing to do.
    if not trimmed or trimmed[-1] in ".!?":
        return trimmed
    cut = max(trimmed.rfind(mark) for mark in ".!?")
    if cut == -1:
        return trimmed + "."
    return trimmed[:cut + 1]
def iter_summaries(text: str, length_ratio: float, request_id: Optional[str] = None,
                   words_per_chunk: Optional[int] = None) -> Iterator[tuple[int, str, float]]:
    """
    Split text into fixed-size word chunks and yield a summary per chunk.

    Args:
        text: full document text.
        length_ratio: scales per-chunk summary min/max lengths (caller
            validates the 0.1-1.0 range).
        request_id: optional request id, included in error logs.
        words_per_chunk: chunk size in words; defaults to the module-level
            CHUNK_WORDS. (Bug fix: this was hard-coded to 800, which
            disagreed with the CHUNK_WORDS-based chunk_count and ETA
            computed in /summarize.)

    Yields:
        (chunk_index, summary_text, seconds_spent) tuples. A failed chunk
        yields an inline "[summarization error: ...]" marker instead of
        aborting the stream.
    """
    size = words_per_chunk if words_per_chunk is not None else CHUNK_WORDS
    words = text.split()
    chunks = [" ".join(words[i:i + size]) for i in range(0, len(words), size)]
    for idx, chunk in enumerate(chunks):
        chunk_word_count = len(chunk.split())
        # Length penalty scales with chunk size to balance brevity vs coverage.
        lp = 0.5 + min(1.5, (chunk_word_count / 1000) * 1.5)
        # min_length and max_length proportional to chunk size and length_ratio.
        min_len = max(20, int(chunk_word_count * 0.05 * length_ratio))
        max_len = max(min_len + 10, int(chunk_word_count * 0.25 * length_ratio))
        try:
            t0 = time.time()
            result = summarizer(
                chunk,
                min_length=min_len,
                max_length=max_len,
                length_penalty=lp,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
                num_beams=4,
            )
            duration = time.time() - t0
            summary = result[0]["summary_text"] if isinstance(result, list) else result["summary_text"]
        except Exception as e:
            # Keep streaming on per-chunk failure; surface the error in-line
            # and log it so failures are not silent on the server side.
            logger.warning(f"Request {request_id}: summarization failed on chunk {idx}: {e}")
            summary = f"[summarization error: {str(e)}]"
            duration = 0.0
        yield idx, finalize_sentence(summary), duration
@app.post("/summarize")
async def summarize(pdf_file: UploadFile = File(...), length_ratio: float = 0.5):
    """
    Summarize an uploaded PDF chunk-by-chunk, streaming JSON Lines.

    Args:
        pdf_file: the uploaded file; content type must be application/pdf.
        length_ratio: scales each chunk's summary length; must lie in [0.1, 1.0].

    Returns:
        StreamingResponse of newline-delimited JSON objects, one per chunk:
        {"request_id", "chunk", "summary", "estimate_seconds"}.

    Raises:
        HTTPException(400): non-PDF upload, out-of-range length_ratio,
            no extractable text, or more than MAX_CHUNKS chunks.
    """
    if pdf_file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if not (0.1 <= length_ratio <= 1.0):
        raise HTTPException(status_code=400, detail="length_ratio must be between 0.1 and 1.0")
    # Whole upload is read into memory; nothing is written to disk.
    pdf_bytes = await pdf_file.read()
    text = pdf_bytes_to_text(pdf_bytes)
    if not text.strip():
        raise HTTPException(status_code=400, detail="PDF contains no readable text.")
    # Safety guard: limit number of chunks to avoid long processing times on CPU/free tier.
    total_words = len(text.split())
    chunk_count = math.ceil(total_words / CHUNK_WORDS) if CHUNK_WORDS > 0 else 1
    logger.info(f"Document text length: {total_words} words; chunks: {chunk_count}")
    if chunk_count > MAX_CHUNKS:
        raise HTTPException(
            status_code=400,
            detail=f"Document too long: requires {chunk_count} chunks (max {MAX_CHUNKS}). Please reduce the PDF size or length_ratio.",
        )
    # Per-request identifier and timing for log correlation across the stream.
    request_id = uuid.uuid4().hex
    start_time = time.time()
    logger.info(
        f"Request {request_id}: starting. words={total_words}, chunks={chunk_count}, model={current_model_name}"
    )

    def gen() -> Iterator[bytes]:
        # Runs lazily inside StreamingResponse: one JSON line per chunk.
        # NOTE(review): the ETA below assumes iter_summaries chunks at
        # CHUNK_WORDS words per chunk — confirm the two stay in sync.
        durations = []
        for idx, summary, duration in iter_summaries(text, length_ratio, request_id):
            durations.append(duration)
            # Running average of per-chunk time gives a rough remaining-time estimate.
            avg = sum(durations) / len(durations) if durations else 0.0
            remaining = max(0, chunk_count - idx - 1)
            est_sec = remaining * avg
            payload = {
                "request_id": request_id,
                "chunk": idx,
                "summary": summary,
                "estimate_seconds": round(est_sec, 2),
            }
            yield (json.dumps(payload) + "\n").encode("utf-8")
        # Finalize logging after streaming completes.
        logger.info(
            f"Request {request_id} finished: chunks={chunk_count}, total_words={total_words}, model={current_model_name}, duration={time.time()-start_time:.2f}s"
        )

    return StreamingResponse(gen(), media_type="application/jsonlines")
@app.get("/health")
async def health():
    """Liveness probe: report that the service is up."""
    payload = {"status": "online"}
    return payload