import os
import logging
import math
import time
import uuid
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import io
import fitz # PyMuPDF
import json
from transformers import pipeline
from typing import Iterator, Optional
import re
# Summarizer model selection: Vietnamese-optimized model by default,
# overridable via the SUMMARIZER_MODEL_VI_VN environment variable.
MODEL_NAME = os.getenv("SUMMARIZER_MODEL_VI_VN", "VietAI/vit5-base-vietnamese")
# Optimized for CPU usage on Hugging Face Free tier.
PRIMARY_VI_MODEL = MODEL_NAME
# Lighter multilingual model used when the primary model fails to load.
FALLBACK_MODEL = "google/mt5-small"
# Chunk and safety configuration (CPU-friendly), configurable via environment.
CHUNK_WORDS = int(os.getenv("CHUNK_WORDS", "600"))  # smaller chunks to reduce per-chunk compute
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "20"))  # safety limit to avoid long processing times
logger = logging.getLogger("pdf_summarizer")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
app = FastAPI(title="PDF Summarizer with Streaming", version="0.1.0")
# CORS: allow all origins — NOTE(review): open CORS looks intentional for a
# public demo; tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Summarization pipeline instance, loaded once at startup and reused for all
# requests; None until load_model() has run.
summarizer = None
# Name of the model actually loaded (primary or fallback), set at startup.
current_model_name: Optional[str] = None
@app.on_event("startup")
def load_model():
    """Initialize the shared summarization pipeline once at startup.

    Attempts the Vietnamese-optimized model first; if loading fails for any
    reason (download error, memory pressure), falls back to the lighter
    multilingual mT5-small so the service can still come up.
    """
    global summarizer, current_model_name

    current_model_name = PRIMARY_VI_MODEL
    try:
        logger.info(f"Loading Vietnamese model for CPU: {PRIMARY_VI_MODEL}")
        summarizer = pipeline("summarization", model=PRIMARY_VI_MODEL)
        logger.info("Vietnamese model loaded successfully.")
    except Exception as exc:
        # Record which model is actually serving requests before loading it.
        logger.warning(f"Failed to load Vietnamese model ({PRIMARY_VI_MODEL}) due to: {exc}. Falling back to MT5-small.")
        current_model_name = FALLBACK_MODEL
        summarizer = pipeline("summarization", model=FALLBACK_MODEL)
        logger.info("Fallback model MT5-small loaded.")
def pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from an in-memory PDF without writing to disk.

    Args:
        pdf_bytes: Raw bytes of a PDF document.

    Returns:
        Text of all pages joined by newlines; pages yielding no text are
        skipped.

    Raises:
        Whatever fitz raises for invalid/corrupt PDF data.
    """
    # Use the Document as a context manager so it is closed even if text
    # extraction raises; the previous unconditional doc.close() leaked the
    # document on any per-page error.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        page_texts = [page.get_text("text") for page in doc]
    return "\n".join(t for t in page_texts if t)
def finalize_sentence(text: str) -> str:
    """
    Make sure text ends on a complete sentence.

    The trimmed text is returned unchanged when it already ends with '.',
    '!' or '?'. Otherwise it is cut at the last such punctuation mark, or,
    when none exists, a period is appended.
    """
    trimmed = text.strip()
    if not trimmed:
        return trimmed
    if trimmed.endswith((".", "!", "?")):
        return trimmed
    # Position of the last sentence-ending punctuation, -1 if absent.
    cut = max(trimmed.rfind(mark) for mark in ".!?")
    if cut != -1 and cut < len(trimmed) - 1:
        return trimmed[:cut + 1]
    return trimmed + "."
def iter_summaries(text: str, length_ratio: float, request_id: Optional[str] = None) -> Iterator[tuple[int, str, float]]:
    """
    Split text into CHUNK_WORDS-sized blocks and yield a summary per chunk.

    Fix: this previously chunked by a hard-coded 800 words while the
    /summarize endpoint computed its MAX_CHUNKS guard and time estimates
    with CHUNK_WORDS (default 600), so the guard and the streamed
    `estimate_seconds` were based on a different chunking than was actually
    performed. Chunking by CHUNK_WORDS keeps both in sync.

    Args:
        text: Full document text.
        length_ratio: Factor (0.1-1.0) scaling the per-chunk summary length.
        request_id: Optional correlation id (currently unused; kept for
            interface compatibility).

    Yields:
        Tuples of (chunk_index, finalized_summary_text, seconds_taken).
    """
    words = text.split()
    # Guard against a misconfigured CHUNK_WORDS=0 (range() step must be > 0).
    step = CHUNK_WORDS if CHUNK_WORDS > 0 else (len(words) or 1)
    chunks = [" ".join(words[i:i + step]) for i in range(0, len(words), step)]
    for idx, chunk in enumerate(chunks):
        chunk_word_count = len(chunk.split())
        # Length penalty scales with chunk size to balance brevity vs coverage.
        lp = 0.5 + min(1.5, (chunk_word_count / 1000) * 1.5)
        # min_length and max_length proportional to chunk size and length_ratio.
        min_len = max(20, int(chunk_word_count * 0.05 * length_ratio))
        max_len = max(min_len + 10, int(chunk_word_count * 0.25 * length_ratio))
        try:
            t0 = time.time()
            result = summarizer(
                chunk,
                min_length=min_len,
                max_length=max_len,
                length_penalty=lp,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
                num_beams=4,
            )
            duration = time.time() - t0
            summary = result[0]["summary_text"] if isinstance(result, list) else result["summary_text"]
        except Exception as e:
            # Report the failure in-band so streaming continues for later chunks.
            summary = f"[summarization error: {str(e)}]"
            duration = 0.0
        yield idx, finalize_sentence(summary), duration
@app.post("/summarize")
async def summarize(pdf_file: UploadFile = File(...), length_ratio: float = 0.5):
    """
    Accept an uploaded PDF and stream chunk-wise summaries as JSON Lines.

    Validates the upload and length_ratio, extracts text fully in memory,
    rejects documents whose estimated chunk count exceeds MAX_CHUNKS, then
    streams one JSON object per summarized chunk, each carrying a rough ETA
    for the remaining work.
    """
    if pdf_file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if not (0.1 <= length_ratio <= 1.0):
        raise HTTPException(status_code=400, detail="length_ratio must be between 0.1 and 1.0")

    text = pdf_bytes_to_text(await pdf_file.read())
    if not text.strip():
        raise HTTPException(status_code=400, detail="PDF contains no readable text.")

    # Safety guard: cap the number of chunks so free-tier CPU requests
    # cannot run unboundedly long.
    total_words = len(text.split())
    chunk_count = math.ceil(total_words / CHUNK_WORDS) if CHUNK_WORDS > 0 else 1
    logger.info(f"Document text length: {total_words} words; chunks: {chunk_count}")
    if chunk_count > MAX_CHUNKS:
        raise HTTPException(
            status_code=400,
            detail=f"Document too long: requires {chunk_count} chunks (max {MAX_CHUNKS}). Please reduce the PDF size or length_ratio.",
        )

    # Per-request correlation id and wall-clock timer for logging.
    request_id = uuid.uuid4().hex
    start_time = time.time()
    logger.info(
        f"Request {request_id}: starting. words={total_words}, chunks={chunk_count}, model={current_model_name}"
    )

    def gen() -> Iterator[bytes]:
        chunk_times: list[float] = []
        for idx, summary, duration in iter_summaries(text, length_ratio, request_id):
            chunk_times.append(duration)
            # ETA = remaining chunks * mean observed per-chunk duration.
            mean = sum(chunk_times) / len(chunk_times) if chunk_times else 0.0
            left = max(0, chunk_count - idx - 1)
            payload = {
                "request_id": request_id,
                "chunk": idx,
                "summary": summary,
                "estimate_seconds": round(left * mean, 2),
            }
            yield (json.dumps(payload) + "\n").encode("utf-8")
        # Completion log fires only once the stream is fully consumed.
        logger.info(
            f"Request {request_id} finished: chunks={chunk_count}, total_words={total_words}, model={current_model_name}, duration={time.time()-start_time:.2f}s"
        )

    return StreamingResponse(gen(), media_type="application/jsonlines")
@app.get("/health")
async def health():
    """Lightweight liveness endpoint for uptime checks."""
    return dict(status="online")