import os
import logging
import math
import time
import uuid
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import io
import fitz  # PyMuPDF
import json
from transformers import pipeline
from typing import Iterator, Optional
import re

# Model name: default Vietnamese-optimized model, overridable via the
# SUMMARIZER_MODEL_VI_VN environment variable.
MODEL_NAME = os.getenv("SUMMARIZER_MODEL_VI_VN", "VietAI/vit5-base-vietnamese")
# Optimized for CPU usage on Hugging Face Free tier
PRIMARY_VI_MODEL = MODEL_NAME
FALLBACK_MODEL = "google/mt5-small"  # used by load_model() if the primary model fails to load

# Chunk and safety configuration (CPU-friendly), configurable via environment
CHUNK_WORDS = int(os.getenv("CHUNK_WORDS", "600"))  # smaller chunks to reduce per-chunk compute
MAX_CHUNKS = int(os.getenv("MAX_CHUNKS", "20"))       # safety limit to avoid long processing times

logger = logging.getLogger("pdf_summarizer")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")

app = FastAPI(title="PDF Summarizer with Streaming", version="0.1.0")

# CORS: allow all origins, methods, and headers — convenient for a public demo,
# but should be tightened for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Summarizer instance loaded at startup, reused for all requests
summarizer = None          # transformers summarization pipeline, set by load_model()
current_model_name = None  # name of whichever model actually loaded (primary or fallback)

@app.on_event("startup")
def load_model():
    """
    Load the summarization pipeline once at startup and cache it module-wide.

    Tries PRIMARY_VI_MODEL first; on any load failure falls back to
    FALLBACK_MODEL. Sets the module globals `summarizer` (the pipeline) and
    `current_model_name` (the model that actually loaded) used by requests.
    """
    global summarizer, current_model_name
    model_to_load = PRIMARY_VI_MODEL
    current_model_name = model_to_load
    try:
        # Lazy %-style logging args: formatting happens only if the level is enabled.
        logger.info("Loading Vietnamese model for CPU: %s", model_to_load)
        summarizer = pipeline("summarization", model=model_to_load)
        logger.info("Vietnamese model loaded successfully.")
    except Exception as e:
        logger.warning(
            "Failed to load Vietnamese model (%s) due to: %s. Falling back to MT5-small.",
            model_to_load,
            e,
        )
        current_model_name = FALLBACK_MODEL
        # If the fallback also fails, the exception propagates and startup aborts
        # loudly — preferable to serving requests with summarizer = None.
        summarizer = pipeline("summarization", model=FALLBACK_MODEL)
        logger.info("Fallback model MT5-small loaded.")

def pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from a PDF provided as in-memory bytes (no disk I/O).

    Args:
        pdf_bytes: Raw bytes of a PDF document.

    Returns:
        The "text" extraction of all pages joined by newlines; pages that
        yield no text are skipped.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        texts = []
        for page in doc:
            text = page.get_text("text")
            if text:
                texts.append(text)
    finally:
        # Close even if extraction raises — the original leaked the document
        # handle on a get_text() failure.
        doc.close()
    return "\n".join(texts)

def finalize_sentence(text: str) -> str:
    """
    Return *text* trimmed and guaranteed to end on sentence punctuation.

    If the trimmed text already ends with '.', '!' or '?' it is returned
    unchanged. Otherwise the text is truncated at the last such punctuation
    mark when one exists, or a period is appended when none does.
    """
    cleaned = text.strip()
    if not cleaned:
        return cleaned
    terminators = ".!?"
    if cleaned.endswith(tuple(terminators)):
        return cleaned
    cut = max(cleaned.rfind(ch) for ch in terminators)
    if cut != -1 and cut < len(cleaned) - 1:
        return cleaned[:cut + 1]
    return cleaned + "."

def iter_summaries(text: str, length_ratio: float, request_id: Optional[str] = None) -> Iterator[tuple[int, str, float]]:
    """
    Split *text* into ~CHUNK_WORDS-word blocks and lazily yield one summary per block.

    Args:
        text: Full document text to summarize.
        length_ratio: Scale factor (expected 0.1-1.0) applied to min/max summary lengths.
        request_id: Optional correlation id; accepted for caller compatibility,
            not used in the summarization itself.

    Yields:
        (chunk_index, summary_text, seconds_spent) tuples. On a summarizer
        failure the summary is an inline "[summarization error: ...]" marker
        with duration 0.0, so the stream continues instead of aborting.
    """
    # Use the module-wide CHUNK_WORDS so chunking agrees with the chunk_count
    # guard and ETA computed in /summarize. (Was hard-coded to 800 words,
    # which made the streamed estimates and MAX_CHUNKS guard inconsistent
    # whenever CHUNK_WORDS != 800.)
    words = text.split()
    chunks = [" ".join(words[i:i + CHUNK_WORDS]) for i in range(0, len(words), CHUNK_WORDS)]

    for idx, chunk in enumerate(chunks):
        chunk_word_count = len(chunk.split())
        # Length penalty scales with chunk size to balance brevity vs coverage.
        lp = 0.5 + min(1.5, (chunk_word_count / 1000) * 1.5)

        # min_length and max_length proportional to chunk size and length_ratio,
        # with floors so tiny chunks still get a usable generation window.
        min_len = max(20, int(chunk_word_count * 0.05 * length_ratio))
        max_len = max(min_len + 10, int(chunk_word_count * 0.25 * length_ratio))

        try:
            t0 = time.time()
            result = summarizer(
                chunk,
                min_length=min_len,
                max_length=max_len,
                length_penalty=lp,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
                num_beams=4
            )
            duration = time.time() - t0
            # The pipeline may return a list of dicts or a single dict.
            summary = result[0]["summary_text"] if isinstance(result, list) else result["summary_text"]
        except Exception as e:
            # Best-effort: surface the failure inline rather than killing the stream.
            summary = f"[summarization error: {str(e)}]"
            duration = 0.0

        summary = finalize_sentence(summary)
        yield idx, summary, duration

@app.post("/summarize")
async def summarize(pdf_file: UploadFile = File(...), length_ratio: float = 0.5):
    """
    Summarize an uploaded PDF and stream chunk-wise summaries as JSON Lines.

    Args:
        pdf_file: Uploaded PDF, read fully into memory (never written to disk).
        length_ratio: Summary length scale in [0.1, 1.0], forwarded to iter_summaries.

    Returns:
        StreamingResponse of newline-delimited JSON objects, one per chunk:
        {"request_id", "chunk", "summary", "estimate_seconds"}.

    Raises:
        HTTPException(400): non-PDF content type, out-of-range length_ratio,
        a PDF with no extractable text, or a document exceeding MAX_CHUNKS.
    """
    # NOTE(review): content_type is client-supplied and may carry parameters
    # (e.g. "application/pdf; charset=..."), which this strict equality rejects
    # — confirm that is intended.
    if pdf_file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if not (0.1 <= length_ratio <= 1.0):
        raise HTTPException(status_code=400, detail="length_ratio must be between 0.1 and 1.0")

    pdf_bytes = await pdf_file.read()
    text = pdf_bytes_to_text(pdf_bytes)
    if not text.strip():
        raise HTTPException(status_code=400, detail="PDF contains no readable text.")
    # Safety guard: limit number of chunks to avoid long processing times on CPU/free tier
    total_words = len(text.split())
    chunk_count = math.ceil(total_words / CHUNK_WORDS) if CHUNK_WORDS > 0 else 1
    logger.info(f"Document text length: {total_words} words; chunks: {chunk_count}")
    if chunk_count > MAX_CHUNKS:
        raise HTTPException(
            status_code=400,
            detail=f"Document too long: requires {chunk_count} chunks (max {MAX_CHUNKS}). Please reduce the PDF size or length_ratio.",
        )

    # Per-request identifiers and timing for enhanced logging
    request_id = uuid.uuid4().hex
    start_time = time.time()
    logger.info(
        f"Request {request_id}: starting. words={total_words}, chunks={chunk_count}, model={current_model_name}"
    )

    def gen() -> Iterator[bytes]:
        # Emit one JSON line per chunk as soon as it is summarized, with a
        # naive ETA: (average per-chunk duration so far) * (remaining chunks).
        durations = []
        for idx, summary, duration in iter_summaries(text, length_ratio, request_id):
            durations.append(duration)
            avg = sum(durations) / len(durations) if durations else 0.0
            # NOTE(review): chunk_count is derived from CHUNK_WORDS, while
            # iter_summaries uses its own chunk size — verify the two agree,
            # otherwise the ETA and "remaining" math can diverge.
            remaining = max(0, chunk_count - idx - 1)
            est_sec = remaining * avg
            payload = {
                "request_id": request_id,
                "chunk": idx,
                "summary": summary,
                "estimate_seconds": round(est_sec, 2),
            }
            yield (json.dumps(payload) + "\n").encode("utf-8")
        # Finalize logging after streaming completes
        logger.info(
            f"Request {request_id} finished: chunks={chunk_count}, total_words={total_words}, model={current_model_name}, duration={time.time()-start_time:.2f}s"
        )

    return StreamingResponse(gen(), media_type="application/jsonlines")

@app.get("/health")
async def health():
    """Simple liveness endpooint: always reports that the service is online."""
    return dict(status="online")