"""FastAPI service for BART-based text summarization.

Exposes single and batch summarize endpoints backed by a fine-tuned
BART model loaded once at startup (GPU when available).
"""

import logging
import os
from contextlib import asynccontextmanager

import sentry_sdk
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

load_dotenv()

# Monitor errors using Sentry (a missing SENTRY_DSN disables reporting).
sentry_sdk.init(
    dsn=os.environ.get("SENTRY_DSN"),
    send_default_pii=True,
)
sentry_sdk.set_tag("service", "api")

logger = logging.getLogger(__name__)


# 1. Define Request Schemas
class SummarizeRequest(BaseModel):
    """Single summarization request: input text and optional max summary length (tokens)."""

    text: str
    max_length: int = 80  # capped server-side at MAX_SUMMARY_TOKENS


class SummarizeBatchRequest(BaseModel):
    """List of items to summarize in one batched forward pass.

    Each item has its own max_length.
    """

    items: list[SummarizeRequest]


# 2. Limits and generate config
MAX_BATCH_SIZE = 16
MAX_SUMMARY_TOKENS = 60  # hard server-side cap on generated summary tokens
MAX_INPUT_TOKENS = 1024  # encoder context limit; longer inputs are truncated
GENERATE_CONFIG = {
    "num_beams": 4,
    "min_length": 30,
    "no_repeat_ngram_size": 2,
    "length_penalty": 2.0,
    "repetition_penalty": 2.5,
    "early_stopping": True,
}

# Populated by `lifespan` on startup: "model", "tokenizer", "device".
ml_models = {}


def _capped_length(requested: int) -> int:
    """Clamp a requested summary length to the valid range [1, MAX_SUMMARY_TOKENS].

    Guards against non-positive client values, which `model.generate`
    would reject, as well as values above the server-side cap.
    """
    return max(1, min(requested, MAX_SUMMARY_TOKENS))


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model and tokenizer on startup (GPU if available), clear on shutdown."""
    # DYNAMIC DEVICE DETECTION
    # Checks for NVIDIA GPU, otherwise defaults to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- System Check: Using {device.upper()} ---")
    if device == "cuda":
        print(f"--- GPU Name: {torch.cuda.get_device_name(0)} ---")

    # LOAD MODEL (Hugging Face Hub repo id; tokenizer from the same id so
    # vocabulary and weights stay consistent)
    model_id = "janrudolf/bart-finetuned-wobbly-bush-56-h9vtgsw1"
    ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_id)
    ml_models["model"] = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
    ml_models["device"] = device

    yield

    # Cleanup
    ml_models.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)


# 3. API Endpoints
@app.get("/health")
def health():
    """Return service readiness, device (cpu/cuda), and PyTorch version."""
    return {
        "status": "ready",
        "device": ml_models.get("device", "unknown"),
        "torch_version": torch.__version__,
    }


@app.post("/summarize")
async def summarize(request: SummarizeRequest):
    """Summarize a single text.

    Returns the summary and the device used. Raises 507 on GPU OOM and
    500 on any other failure.
    """
    try:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        # Tokenize with truncation to the encoder context limit.
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            # Pass the attention mask explicitly (consistent with the
            # batch endpoint) instead of letting transformers infer it.
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=_capped_length(request.max_length),
                **GENERATE_CONFIG,
            )

        # Seq2Seq output is the summary only (no input tokens)
        summary_text = tokenizer.decode(
            summary_ids[0], skip_special_tokens=True
        ).strip()
        return {"summary": summary_text, "used_device": device}
    except torch.cuda.OutOfMemoryError:
        raise HTTPException(status_code=507, detail="GPU Memory Full")
    except Exception as e:
        logger.exception("Summarize failed")
        sentry_sdk.capture_exception(e)
        raise HTTPException(status_code=500, detail="Internal server error")


@app.post("/summarize/batch")
async def summarize_batch(request: SummarizeBatchRequest):
    """Summarize multiple texts in one batched forward pass.

    Returns summaries in the same order as items. At most MAX_BATCH_SIZE
    items per request (422 if exceeded). Raises 507 on GPU OOM and 500 on
    any other failure.
    """
    if not request.items:
        return {"summaries": [], "used_device": ml_models.get("device", "unknown")}
    if len(request.items) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=422,
            detail=f"Batch size {len(request.items)} exceeds maximum of {MAX_BATCH_SIZE}",
        )

    try:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        texts = [item.text for item in request.items]
        max_lengths = [_capped_length(item.max_length) for item in request.items]

        # tokenize batch: padding to same length, truncation
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
            padding=True,
            return_attention_mask=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Single generate call sized for the longest requested summary.
        max_gen = max(max_lengths)
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_gen,
                **GENERATE_CONFIG,
            )

        # Decode each sequence, truncating to per-item max_length (in
        # tokens). NOTE: the slice also counts special/pad tokens, so an
        # item may receive slightly fewer content tokens than requested.
        summaries = []
        for ids, cap in zip(summary_ids, max_lengths):
            summary_text = tokenizer.decode(
                ids[:cap], skip_special_tokens=True
            ).strip()
            summaries.append(summary_text)
        return {"summaries": summaries, "used_device": device}
    except torch.cuda.OutOfMemoryError:
        raise HTTPException(status_code=507, detail="GPU Memory Full")
    except Exception as e:
        logger.exception("Summarize batch failed")
        sentry_sdk.capture_exception(e)
        raise HTTPException(status_code=500, detail="Internal server error")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)