import io
import json
import logging
import math
import os
import re
import time
import uuid
from typing import Iterator, Optional

import fitz  # PyMuPDF
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import pipeline

# Summarization model: Vietnamese-optimized by default (override with the
# SUMMARIZER_MODEL_VI_VN environment variable), intended for CPU inference on
# the Hugging Face free tier. FALLBACK_MODEL is loaded if the primary fails.
MODEL_NAME = os.getenv("SUMMARIZER_MODEL_VI_VN", "VietAI/vit5-base-vietnamese")
PRIMARY_VI_MODEL = MODEL_NAME
FALLBACK_MODEL = "google/mt5-small"

# Chunking and safety limits (CPU-friendly), configurable via environment.
# CHUNK_WORDS is the single chunk size used both to split the text
# (iter_summaries) and to compute the chunk count that /summarize checks
# against MAX_CHUNKS and uses for time estimates; the two must agree.
CHUNK_WORDS = int(os.getenv("CHUNK_WORDS", "600"))  # smaller chunks reduce per-chunk compute
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "20"))  # safety limit to avoid long processing times

logger = logging.getLogger("pdf_summarizer")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

app = FastAPI(title="PDF Summarizer with Streaming", version="0.1.0")

# CORS: allow all origins, methods and headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Summarization pipeline loaded once at startup and reused by all requests.
summarizer = None
current_model_name = None


@app.on_event("startup")
def load_model():
    """Load the summarization pipeline once and store it in module globals.

    Tries PRIMARY_VI_MODEL first. On any load failure it logs a warning and
    loads FALLBACK_MODEL instead. A failure of the fallback is not caught here,
    so it propagates out of the startup hook.
    """
    global summarizer, current_model_name
    model_to_load = PRIMARY_VI_MODEL
    current_model_name = model_to_load
    try:
        logger.info(f"Loading Vietnamese model for CPU: {model_to_load}")
        summarizer = pipeline("summarization", model=model_to_load)
        logger.info("Vietnamese model loaded successfully.")
    except Exception as e:
        logger.warning(
            f"Failed to load Vietnamese model ({model_to_load}) due to: {e}. Falling back to MT5-small."
        )
        current_model_name = FALLBACK_MODEL
        summarizer = pipeline("summarization", model=FALLBACK_MODEL)
        logger.info("Fallback model MT5-small loaded.")


def pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    """Extract plain text from an in-memory PDF without writing it to disk.

    Pages whose extracted text is empty are skipped; the remaining page texts
    are joined with newlines. The document is closed even if extraction raises
    part-way through (the original leaked it in that case).
    """
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        texts = [text for text in (page.get_text("text") for page in doc) if text]
    return "\n".join(texts)


def finalize_sentence(text: str) -> str:
    """Return *text* stripped and ending in '.', '!' or '?'.

    If the text does not already end with terminal punctuation, it is cut back
    to the last terminal punctuation mark (dropping a trailing fragment, e.g.
    one truncated by the generation length limit). If there is no such mark
    at all, a period is appended. Empty/whitespace input returns "".
    """
    t = text.strip()
    if not t or t[-1] in ".!?":
        return t
    # t[-1] is not punctuation here, so any hit lies strictly before the end.
    last_p = max(t.rfind(mark) for mark in ".!?")
    if last_p != -1:
        return t[: last_p + 1]
    return t + "."


def iter_summaries(
    text: str, length_ratio: float, request_id: Optional[str] = None
) -> Iterator[tuple[int, str, float]]:
    """Split *text* into CHUNK_WORDS-word chunks and yield one summary per chunk.

    Args:
        text: Full document text; split on whitespace into words.
        length_ratio: Scales the per-chunk min/max generation length.
        request_id: Included in log messages to correlate errors with a request.

    Yields:
        ``(chunk_index, summary, seconds_spent)`` in chunk order. If summarizing
        a chunk raises, the exception is logged and a
        ``"[summarization error: ...]"`` placeholder is yielded with a duration
        of 0.0, so the stream continues with the next chunk.
    """
    words = text.split()
    # Same chunk size (and same "non-positive means one chunk" rule) as the
    # MAX_CHUNKS guard in /summarize, so the number of yielded chunks matches
    # the count the caller validated and uses for ETA estimates.
    chunk_size = CHUNK_WORDS if CHUNK_WORDS > 0 else max(1, len(words))

    for idx, start in enumerate(range(0, len(words), chunk_size)):
        chunk_words = words[start:start + chunk_size]
        chunk = " ".join(chunk_words)
        chunk_word_count = len(chunk_words)

        # Length penalty scales with chunk size to balance brevity vs coverage.
        lp = 0.5 + min(1.5, (chunk_word_count / 1000) * 1.5)
        # min_length and max_length proportional to chunk size and length_ratio.
        min_len = max(20, int(chunk_word_count * 0.05 * length_ratio))
        max_len = max(min_len + 10, int(chunk_word_count * 0.25 * length_ratio))

        try:
            t0 = time.perf_counter()
            result = summarizer(
                chunk,
                min_length=min_len,
                max_length=max_len,
                length_penalty=lp,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
                num_beams=4,
            )
            duration = time.perf_counter() - t0
            summary = result[0]["summary_text"] if isinstance(result, list) else result["summary_text"]
        except Exception as e:
            logger.exception(f"Request {request_id}: summarization failed for chunk {idx}")
            # Yield the placeholder as-is: running it through finalize_sentence
            # could cut it at a '.' inside the exception text.
            yield idx, f"[summarization error: {str(e)}]", 0.0
            continue

        yield idx, finalize_sentence(summary), duration


@app.post("/summarize")
async def summarize(pdf_file: UploadFile = File(...), length_ratio: float = 0.5):
    """Summarize an uploaded PDF chunk by chunk, streaming results as JSON Lines.

    Args:
        pdf_file: Uploaded file; its content type must be ``application/pdf``.
        length_ratio: Scales per-chunk summary length; must be in [0.1, 1.0].

    Returns:
        A StreamingResponse where each line is a JSON object with
        ``request_id``, ``chunk`` (0-based index), ``summary`` and
        ``estimate_seconds`` (remaining chunks times the mean per-chunk
        duration observed so far).

    Raises:
        HTTPException: 400 for a non-PDF upload, out-of-range length_ratio,
            an unreadable PDF, a PDF without extractable text, or a document
            needing more than MAX_CHUNKS chunks; 503 if no model is loaded.
    """
    if pdf_file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if not (0.1 <= length_ratio <= 1.0):
        raise HTTPException(status_code=400, detail="length_ratio must be between 0.1 and 1.0")
    if summarizer is None:
        # E.g. startup hooks did not run; fail fast instead of streaming one
        # error placeholder per chunk.
        raise HTTPException(status_code=503, detail="Summarization model is not loaded.")

    pdf_bytes = await pdf_file.read()
    try:
        # PDF parsing is CPU-bound; run it off the event loop so other
        # requests (e.g. /health) are not blocked while it runs.
        text = await run_in_threadpool(pdf_bytes_to_text, pdf_bytes)
    except Exception as e:
        logger.warning(f"Failed to parse uploaded PDF: {e}")
        raise HTTPException(status_code=400, detail="Could not read the PDF file.") from e
    if not text.strip():
        raise HTTPException(status_code=400, detail="PDF contains no readable text.")

    # Safety guard: limit number of chunks to avoid long processing times on CPU/free tier.
    # Must use the same chunk size as iter_summaries (CHUNK_WORDS).
    total_words = len(text.split())
    chunk_count = math.ceil(total_words / CHUNK_WORDS) if CHUNK_WORDS > 0 else 1
    logger.info(f"Document text length: {total_words} words; chunks: {chunk_count}")
    if chunk_count > MAX_CHUNKS:
        # length_ratio does not affect the chunk count, so only the document
        # size can bring it under the limit.
        raise HTTPException(
            status_code=400,
            detail=f"Document too long: requires {chunk_count} chunks (max {MAX_CHUNKS}). Please reduce the PDF size.",
        )

    # Per-request identifiers and timing for enhanced logging.
    request_id = uuid.uuid4().hex
    start_time = time.time()
    logger.info(
        f"Request {request_id}: starting. words={total_words}, chunks={chunk_count}, model={current_model_name}"
    )

    def gen() -> Iterator[bytes]:
        durations = []
        for idx, summary, duration in iter_summaries(text, length_ratio, request_id):
            durations.append(duration)
            avg = sum(durations) / len(durations)
            remaining = max(0, chunk_count - idx - 1)
            payload = {
                "request_id": request_id,
                "chunk": idx,
                "summary": summary,
                "estimate_seconds": round(remaining * avg, 2),
            }
            yield (json.dumps(payload) + "\n").encode("utf-8")
        # Finalize logging after streaming completes.
        logger.info(
            f"Request {request_id} finished: chunks={chunk_count}, total_words={total_words}, model={current_model_name}, duration={time.time()-start_time:.2f}s"
        )

    return StreamingResponse(gen(), media_type="application/jsonlines")


@app.get("/health")
async def health():
    """Liveness probe: always reports the service as online."""
    return {"status": "online"}