Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import math | |
| import time | |
| import uuid | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import io | |
| import fitz # PyMuPDF | |
| import json | |
| from transformers import pipeline | |
| from typing import Iterator, Optional | |
| import re | |
# Summarization models. MODEL_NAME (overridable via SUMMARIZER_MODEL_VI_VN) is
# the Vietnamese model tried first; load_model() switches to FALLBACK_MODEL if
# it cannot be loaded.
MODEL_NAME = os.environ.get("SUMMARIZER_MODEL_VI_VN", "VietAI/vit5-base-vietnamese")
PRIMARY_VI_MODEL = MODEL_NAME
# NOTE(review): google/mt5-small is a pretrained checkpoint rather than one
# fine-tuned for summarization; confirm its output is acceptable as a fallback.
FALLBACK_MODEL = "google/mt5-small"

# CPU-oriented limits, each overridable by an environment variable of the same name:
#   CHUNK_WORDS - words per chunk used when counting a document's chunks
#   MAX_CHUNKS  - requests whose documents need more chunks than this are rejected
CHUNK_WORDS = int(os.environ.get("CHUNK_WORDS", "600"))
MAX_CHUNKS = int(os.environ.get("MAX_CHUNKS", "20"))
# Root logging: INFO and above with timestamps. basicConfig only takes effect
# if the root logger has no handlers yet.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
# Named logger used by every function in this module.
logger = logging.getLogger("pdf_summarizer")
# FastAPI application instance for this service.
app = FastAPI(title="PDF Summarizer with Streaming", version="0.1.0")

# Deliberately permissive CORS: any origin, any HTTP method, any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Module-wide summarization pipeline and the name of the model behind it.
# Both stay None until load_model() assigns them; request handlers reuse them.
summarizer = current_model_name = None
def load_model():
    """
    Create the summarization pipeline and publish it in module globals.

    PRIMARY_VI_MODEL is tried first. If creating that pipeline raises, the
    error is logged and FALLBACK_MODEL is loaded instead. ``summarizer`` and
    ``current_model_name`` are assigned together, only after a pipeline has
    been created, so the reported model name always matches the loaded one.

    NOTE(review): no startup hook (e.g. ``@app.on_event("startup")``) that
    calls this function is visible in this file; confirm it is invoked.

    Raises:
        Any exception raised while loading FALLBACK_MODEL. In that case the
        module globals are left unchanged.
    """
    global summarizer, current_model_name
    logger.info("Loading Vietnamese model for CPU: %s", PRIMARY_VI_MODEL)
    try:
        loaded = pipeline("summarization", model=PRIMARY_VI_MODEL)
        loaded_name = PRIMARY_VI_MODEL
    except Exception as e:
        logger.warning(
            "Failed to load Vietnamese model (%s) due to: %s. Falling back to %s.",
            PRIMARY_VI_MODEL,
            e,
            FALLBACK_MODEL,
        )
        loaded = pipeline("summarization", model=FALLBACK_MODEL)
        loaded_name = FALLBACK_MODEL
        logger.info("Fallback model %s loaded.", FALLBACK_MODEL)
    else:
        logger.info("Vietnamese model loaded successfully.")
    # Publish only once a pipeline actually exists.
    summarizer = loaded
    current_model_name = loaded_name
def pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from an in-memory PDF without writing it to disk.

    Pages are read in order with PyMuPDF's "text" extractor. Pages that yield
    no text are skipped, and the remaining page texts are joined with newlines.

    Args:
        pdf_bytes: Raw PDF file content.

    Returns:
        The concatenated page text, or "" if no page has text.

    Raises:
        Whatever ``fitz.open`` / page extraction raises for data that cannot
        be parsed as a PDF.
    """
    # The context manager closes the document even if extraction raises; the
    # previous explicit close() was skipped on any error.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        page_texts = [page.get_text("text") for page in doc]
    return "\n".join(text for text in page_texts if text)
# Sentence-final punctuation recognised by finalize_sentence().
_SENTENCE_END = ".!?"
# Closing quotes/brackets that may legitimately follow sentence-final punctuation.
_SENTENCE_CLOSERS = "\"'”’)]»"
# A sentence boundary: final punctuation, optional closers, then whitespace.
# Requiring whitespace keeps decimal points ("3.5") and dots inside tokens
# ("v1.2", "example.com") from being treated as the end of a sentence.
_SENTENCE_BOUNDARY_RE = re.compile(
    rf"[{re.escape(_SENTENCE_END)}][{re.escape(_SENTENCE_CLOSERS)}]*(?=\s)"
)


def finalize_sentence(text: str) -> str:
    """
    Return ``text`` stripped and ending on a complete sentence.

    - Empty or whitespace-only input returns "".
    - Text already ending in ".", "!" or "?" (optionally followed by closing
      quotes/brackets) is returned unchanged apart from stripping.
    - Otherwise the text is cut after its last sentence boundary, which is
      sentence-final punctuation (plus closers) followed by whitespace.
    - If there is no boundary at all, a "." is appended.
    """
    t = text.strip()
    if not t:
        return t
    # Allow a closing quote/bracket after the final punctuation mark.
    core = t.rstrip(_SENTENCE_CLOSERS)
    if core and core[-1] in _SENTENCE_END:
        return t
    boundaries = list(_SENTENCE_BOUNDARY_RE.finditer(t))
    if boundaries:
        return t[: boundaries[-1].end()]
    return t + "."
def iter_summaries(
    text: str,
    length_ratio: float,
    request_id: Optional[str] = None,
    chunk_words: Optional[int] = None,
) -> Iterator[tuple[int, str, float]]:
    """
    Split ``text`` into word chunks and yield one summary per chunk.

    The chunk size defaults to the module-level CHUNK_WORDS. summarize() uses
    the same value to compute its chunk count, enforce MAX_CHUNKS and estimate
    remaining time; the previously hard-coded 800 disagreed with it. A
    non-positive size puts the whole text into a single chunk, matching
    summarize()'s one-chunk fallback in that case.

    If no model has been loaded yet, load_model() is called first.

    Args:
        text: Plain text to summarize; it is split on whitespace.
        length_ratio: Scales each chunk's min/max summary length.
        request_id: Optional identifier included in log messages.
        chunk_words: Words per chunk; ``None`` means CHUNK_WORDS.

    Yields:
        ``(chunk_index, summary, seconds)`` in chunk order. If one chunk
        fails, ``summary`` is ``"[summarization error: ...]"`` and ``seconds``
        is 0.0, and the remaining chunks are still processed.

    Raises:
        Whatever load_model() raises if no model is loaded and loading fails.
    """
    size = CHUNK_WORDS if chunk_words is None else chunk_words
    words = text.split()
    if not words:
        return
    if size <= 0:
        # Avoid range() with a zero/negative step; treat the text as one chunk.
        size = len(words)
    chunks = [words[i:i + size] for i in range(0, len(words), size)]

    if summarizer is None:
        # NOTE(review): no startup hook calling load_model() is visible in this
        # file; load lazily so requests do not fail on a None pipeline.
        load_model()

    for idx, chunk in enumerate(chunks):
        chunk_word_count = len(chunk)
        # Length penalty scales with chunk size to balance brevity vs coverage.
        lp = 0.5 + min(1.5, (chunk_word_count / 1000) * 1.5)
        # Summary bounds proportional to chunk size and length_ratio
        # (generation lengths are token counts in transformers, not words).
        min_len = max(20, int(chunk_word_count * 0.05 * length_ratio))
        max_len = max(min_len + 10, int(chunk_word_count * 0.25 * length_ratio))
        t0 = time.perf_counter()
        try:
            result = summarizer(
                " ".join(chunk),
                min_length=min_len,
                max_length=max_len,
                length_penalty=lp,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
                num_beams=4,
            )
            summary = result[0]["summary_text"] if isinstance(result, list) else result["summary_text"]
        except Exception as e:
            # Report the failure in-band so the stream continues, but keep the
            # traceback in the server log instead of swallowing it.
            logger.exception("Request %s: summarization failed for chunk %d", request_id, idx)
            # Error text is not passed through finalize_sentence(), which would
            # otherwise cut it at its first ". ".
            summary = f"[summarization error: {str(e)}]"
            duration = 0.0
        else:
            duration = time.perf_counter() - t0
            summary = finalize_sentence(summary)
            logger.info("Request %s: chunk %d summarized in %.2fs", request_id, idx, duration)
        yield idx, summary, duration
async def summarize(pdf_file: UploadFile = File(...), length_ratio: float = 0.5):
    """
    Receive a PDF upload in memory and stream chunk-wise summaries as JSON Lines.

    Each streamed line is a JSON object with ``request_id``, ``chunk``
    (0-based index), ``summary`` and ``estimate_seconds``. The estimate is the
    number of remaining chunks times the average per-chunk time so far.

    NOTE(review): no route decorator (e.g. ``@app.post("/summarize")``) is
    visible for this handler in this file; confirm it is registered.

    Args:
        pdf_file: Uploaded file; its declared content type must be application/pdf.
        length_ratio: Summary length scale, in [0.1, 1.0].

    Raises:
        HTTPException(400): the content type is wrong, length_ratio is out of
            range, the upload is empty or unreadable as a PDF, the PDF has no
            extractable text, or the text needs more than MAX_CHUNKS chunks.
    """
    from fastapi.concurrency import run_in_threadpool

    if pdf_file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
    if not (0.1 <= length_ratio <= 1.0):
        raise HTTPException(status_code=400, detail="length_ratio must be between 0.1 and 1.0")
    pdf_bytes = await pdf_file.read()
    if not pdf_bytes:
        raise HTTPException(status_code=400, detail="Uploaded PDF file is empty.")
    try:
        # PDF parsing is CPU-bound; run it off the event loop so other requests
        # (e.g. health checks) are not blocked while a large file is parsed.
        text = await run_in_threadpool(pdf_bytes_to_text, pdf_bytes)
    except Exception as exc:
        # A corrupt or non-PDF upload is a client error, not a server error.
        logger.warning("Could not parse uploaded PDF: %s", exc)
        raise HTTPException(
            status_code=400, detail="Could not read the PDF file; it may be corrupted."
        ) from exc
    if not text.strip():
        raise HTTPException(status_code=400, detail="PDF contains no readable text.")

    # Safety guard: cap the chunk count to bound processing time on CPU/free tier.
    total_words = len(text.split())
    chunk_count = math.ceil(total_words / CHUNK_WORDS) if CHUNK_WORDS > 0 else 1
    logger.info("Document text length: %d words; chunks: %d", total_words, chunk_count)
    if chunk_count > MAX_CHUNKS:
        raise HTTPException(
            status_code=400,
            # length_ratio only changes summary length, not the chunk count,
            # so a shorter document is the only remedy to suggest.
            detail=f"Document too long: requires {chunk_count} chunks (max {MAX_CHUNKS}). Please reduce the PDF size.",
        )

    # Per-request identifier and start time for log correlation and timing.
    request_id = uuid.uuid4().hex
    start_time = time.time()
    logger.info(
        "Request %s: starting. words=%d, chunks=%d, model=%s",
        request_id,
        total_words,
        chunk_count,
        current_model_name,
    )

    def gen() -> Iterator[bytes]:
        streamed = 0
        total_seconds = 0.0
        try:
            for idx, summary, duration in iter_summaries(text, length_ratio, request_id):
                streamed += 1
                total_seconds += duration
                avg = total_seconds / streamed
                remaining = max(0, chunk_count - idx - 1)
                payload = {
                    "request_id": request_id,
                    "chunk": idx,
                    "summary": summary,
                    "estimate_seconds": round(remaining * avg, 2),
                }
                # ensure_ascii=False keeps Vietnamese text readable (still valid
                # UTF-8 JSON) instead of emitting \uXXXX escapes.
                yield (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")
        finally:
            # Runs on normal completion, on errors, and if the generator is
            # closed before exhaustion, so every request gets a final log line.
            logger.info(
                "Request %s ended: chunks_streamed=%d/%d, total_words=%d, model=%s, duration=%.2fs",
                request_id,
                streamed,
                chunk_count,
                total_words,
                current_model_name,
                time.time() - start_time,
            )

    return StreamingResponse(gen(), media_type="application/jsonlines")
async def health():
    """
    Report that the service process is up.

    The response is the constant ``{"status": "online"}``; it does not check
    whether a summarization model has been loaded.

    NOTE(review): no route decorator (e.g. ``@app.get("/health")``) is visible
    for this handler in this file; confirm it is registered.
    """
    status_payload = {"status": "online"}
    return status_payload