# text-summarizer / api.py
# Latest commit f6cbcb4 (janrudolf): "Use HF uploaded model"
"""FastAPI service for BART-based text summarization. Exposes single and batch summarize endpoints."""
import logging
import os
import torch
import sentry_sdk
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from dotenv import load_dotenv
load_dotenv()
# Error monitoring: the DSN is read from the SENTRY_DSN environment variable.
# NOTE(review): send_default_pii=True presumably allows request data to be
# attached to events -- confirm this is acceptable, since the endpoints
# receive user-supplied text.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), send_default_pii=True)
# Tag every event from this process so it can be told apart from other services.
sentry_sdk.set_tag("service", "api")

# Module-level logger used by the endpoint error handlers.
logger = logging.getLogger(__name__)
# 1. Define Request Schemas
class SummarizeRequest(BaseModel):
    """Payload for one summarization: the source text plus a length cap.

    The endpoints clamp ``max_length`` to ``MAX_SUMMARY_TOKENS`` with
    ``min()``, so values above that limit (including the default of 80)
    are silently reduced to it.
    """

    # Text to summarize.
    text: str
    # Requested upper bound on the summary length, in tokens.
    max_length: int = 80
class SummarizeBatchRequest(BaseModel):
    """Payload for the batch endpoint: several summarization requests at once.

    Every entry carries its own ``max_length``; summaries are returned in the
    same order as ``items``.
    """

    # Individual requests, processed together by the batch endpoint.
    items: list[SummarizeRequest]
# 2. Service limits and generation settings

# Largest number of items the batch endpoint accepts in one request.
MAX_BATCH_SIZE: int = 16

# Hard ceiling on summary length in tokens; per-request max_length values are
# clamped to this with min().
MAX_SUMMARY_TOKENS: int = 60

# Keyword arguments forwarded unchanged to model.generate() by both endpoints.
GENERATE_CONFIG: dict = {
    "num_beams": 4,
    "min_length": 30,
    "no_repeat_ngram_size": 2,
    "length_penalty": 2.0,
    "repetition_penalty": 2.5,
    "early_stopping": True,
}

# Process-wide registry filled at startup with the "tokenizer", "model" and
# "device" entries, and emptied again at shutdown.
ml_models: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model at startup and release them at shutdown.

    Before yielding, fills the module-level ``ml_models`` dict with the
    ``"tokenizer"``, ``"model"`` and ``"device"`` entries the endpoints read.
    After the application exits (or if startup fails part-way), the dict is
    emptied and, when CUDA is available, the CUDA cache is released.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None; control returns to the application while the model is loaded.
    """
    # Prefer a CUDA device when PyTorch reports one, otherwise run on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- System Check: Using {device.upper()} ---")
    if device == "cuda":
        print(f"--- GPU Name: {torch.cuda.get_device_name(0)} ---")

    # The same identifier is used for tokenizer and model so both come from
    # one checkpoint.
    # NOTE(review): this has the "user/name" shape of a Hugging Face Hub repo
    # id rather than a local directory -- confirm network access is expected
    # at startup.
    model_id = "janrudolf/bart-finetuned-wobbly-bush-56-h9vtgsw1"
    try:
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_id)
        ml_models["model"] = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
        ml_models["device"] = device
        yield
    finally:
        # Runs on normal shutdown, on a failed/partial startup, and when an
        # exception or cancellation is thrown in at the yield point.
        ml_models.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Application instance; the lifespan handler loads the model before requests
# are served and releases it on shutdown.
app = FastAPI(lifespan=lifespan)
# 3. API Endpoints
@app.get("/health")
def health():
    """Report service status, the active device, and the PyTorch version.

    Returns:
        A dict with ``"status"`` (always the literal ``"ready"``),
        ``"device"`` (the value stored in ``ml_models``, or ``"unknown"``
        when none is stored) and ``"torch_version"``.
    """
    active_device = ml_models.get("device", "unknown")
    payload = {"status": "ready"}
    payload["device"] = active_device
    payload["torch_version"] = torch.__version__
    return payload
@app.post("/summarize")
async def summarize(request: SummarizeRequest):
    """Summarize a single text.

    The effective length cap is ``min(request.max_length, MAX_SUMMARY_TOKENS)``.
    The configured ``min_length`` is lowered to that cap when necessary so the
    two constraints passed to ``generate()`` never contradict each other.

    Args:
        request: The text to summarize and the requested maximum summary
            length in tokens.

    Returns:
        A dict with ``"summary"`` (the decoded, stripped summary) and
        ``"used_device"`` (the device stored at startup).

    Raises:
        HTTPException: 422 if ``max_length`` is not positive; 507 when
            ``torch.cuda.OutOfMemoryError`` is raised; 500 for any other
            failure (logged and reported to Sentry).
    """
    # Reject nonsensical caps up front instead of letting generate() fail
    # and surface as an opaque 500.
    if request.max_length < 1:
        raise HTTPException(
            status_code=422,
            detail="max_length must be a positive integer",
        )

    max_len = min(request.max_length, MAX_SUMMARY_TOKENS)
    # Copy the shared config so the module-level dict is never mutated, and
    # keep min_length <= max_length (a small requested cap would otherwise
    # fall below the configured minimum).
    gen_kwargs = {**GENERATE_CONFIG, "max_length": max_len}
    gen_kwargs["min_length"] = min(gen_kwargs.get("min_length", 0), max_len)

    try:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        # Truncate over-long inputs to 1024 tokens.
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                # Passed explicitly for parity with the batch endpoint;
                # None (the default) if the tokenizer returned no mask.
                attention_mask=inputs.get("attention_mask"),
                **gen_kwargs,
            )

        # Seq2seq output holds the summary only (no input tokens).
        summary_text = tokenizer.decode(
            summary_ids[0], skip_special_tokens=True
        ).strip()
        return {"summary": summary_text, "used_device": device}
    except torch.cuda.OutOfMemoryError as e:
        raise HTTPException(status_code=507, detail="GPU Memory Full") from e
    except Exception as e:
        logger.exception("Summarize failed")
        sentry_sdk.capture_exception(e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
@app.post("/summarize/batch")
async def summarize_batch(request: SummarizeBatchRequest):
    """Summarize several texts with a single batched ``generate()`` call.

    Generation runs with the largest per-item cap in the batch (each cap being
    ``min(item.max_length, MAX_SUMMARY_TOKENS)``); every output sequence is
    then cut to its own item's cap, in tokens, before decoding. The configured
    ``min_length`` is lowered to the batch cap when necessary so the length
    constraints passed to ``generate()`` never contradict each other.

    Args:
        request: The items to summarize; at most ``MAX_BATCH_SIZE`` entries.

    Returns:
        A dict with ``"summaries"`` (one string per item, in input order) and
        ``"used_device"``. An empty ``items`` list yields an empty
        ``"summaries"`` list without touching the model.

    Raises:
        HTTPException: 422 if the batch exceeds ``MAX_BATCH_SIZE`` or any
            item's ``max_length`` is not positive; 507 when
            ``torch.cuda.OutOfMemoryError`` is raised; 500 for any other
            failure (logged and reported to Sentry).
    """
    if not request.items:
        return {"summaries": [], "used_device": ml_models.get("device", "unknown")}
    if len(request.items) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=422,
            detail=f"Batch size {len(request.items)} exceeds maximum of {MAX_BATCH_SIZE}",
        )
    # Reject nonsensical caps up front instead of letting generate() fail
    # and surface as an opaque 500.
    if any(item.max_length < 1 for item in request.items):
        raise HTTPException(
            status_code=422,
            detail="max_length must be a positive integer for every item",
        )

    texts = [item.text for item in request.items]
    max_lengths = [
        min(item.max_length, MAX_SUMMARY_TOKENS) for item in request.items
    ]
    max_gen = max(max_lengths)
    # Copy the shared config so the module-level dict is never mutated, and
    # keep min_length <= max_length for batches of small requested caps.
    gen_kwargs = {**GENERATE_CONFIG, "max_length": max_gen}
    gen_kwargs["min_length"] = min(gen_kwargs.get("min_length", 0), max_gen)

    try:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        # Tokenize the whole batch: pad to a common length, truncate to 1024.
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
            padding=True,
            return_attention_mask=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )

        # Cut each sequence to its own item's cap (in tokens) before decoding.
        summaries = [
            tokenizer.decode(ids[:cap], skip_special_tokens=True).strip()
            for ids, cap in zip(summary_ids, max_lengths)
        ]
        return {"summaries": summaries, "used_device": device}
    except torch.cuda.OutOfMemoryError as e:
        raise HTTPException(status_code=507, detail="GPU Memory Full") from e
    except Exception as e:
        logger.exception("Summarize batch failed")
        sentry_sdk.capture_exception(e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on the loopback interface,
    # port 8000. uvicorn is imported here so it is only needed when the file
    # is run as a script.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)