Spaces:
Sleeping
Sleeping
| """FastAPI service for BART-based text summarization. Exposes single and batch summarize endpoints.""" | |
| import logging | |
| import os | |
| import torch | |
| import sentry_sdk | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from contextlib import asynccontextmanager | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file into the process environment before
# anything below reads them.
load_dotenv()

# Monitor errors using Sentry. The DSN is read from the SENTRY_DSN variable.
# NOTE(review): send_default_pii=True lets the SDK attach personally
# identifiable request data to events — confirm this is intended for a service
# that receives user-supplied text.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), send_default_pii=True)
sentry_sdk.set_tag("service", "api")

# Module-level logger used by the endpoint error handlers.
logger = logging.getLogger(__name__)
| # 1. Define Request Schemas | |
class SummarizeRequest(BaseModel):
    """Request body for summarizing one piece of text.

    Attributes:
        text: The input text to summarize.
        max_length: Requested upper bound on the summary length, in tokens.
            Defaults to 80; the endpoints in this file clamp the value to
            MAX_SUMMARY_TOKENS before generating.
    """

    text: str
    max_length: int = 80
class SummarizeBatchRequest(BaseModel):
    """Request body for summarizing several texts in a single call.

    Attributes:
        items: The individual summarization requests. Each entry carries its
            own text and its own max_length; results are returned in the same
            order as this list.
    """

    items: list[SummarizeRequest]
# 2. Limits and generate config

# Largest number of items accepted by the batch endpoint in one request.
MAX_BATCH_SIZE = 16

# Ceiling applied to every requested max_length before generation.
MAX_SUMMARY_TOKENS = 60

# Keyword arguments forwarded unchanged to model.generate() by both endpoints.
# NOTE(review): min_length (30) can exceed a caller's max_length when a small
# cap is requested — confirm how generate() should behave in that case.
GENERATE_CONFIG = dict(
    num_beams=4,
    min_length=30,
    no_repeat_ngram_size=2,
    length_penalty=2.0,
    repetition_penalty=2.5,
    early_stopping=True,
)

# Process-wide registry filled at startup with "tokenizer", "model" and
# "device", and emptied again at shutdown.
ml_models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model at startup and release them at shutdown.

    The device is "cuda" when torch reports an available GPU, otherwise "cpu".
    The tokenizer, the model (moved to that device) and the device name are
    stored in the module-level ``ml_models`` registry for the endpoints to use.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control returns to the framework while the application serves
        requests.
    """
    # DYNAMIC DEVICE DETECTION
    # Checks for NVIDIA GPU, otherwise defaults to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- System Check: Using {device.upper()} ---")
    if device == "cuda":
        print(f"--- GPU Name: {torch.cuda.get_device_name(0)} ---")

    # LOAD MODEL: tokenizer and weights come from the same identifier so they
    # stay consistent with each other.
    model_id = "janrudolf/bart-finetuned-wobbly-bush-56-h9vtgsw1"
    try:
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_id)
        ml_models["model"] = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
        ml_models["device"] = device
        yield
    finally:
        # Cleanup runs even when loading fails part-way or the application
        # exits with an error, so no half-populated registry is left behind.
        ml_models.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)
| # 3. API Endpoints | |
def health():
    """Report service status, the active device and the PyTorch version.

    Returns:
        dict with "status" (always "ready"), "device" (the value stored in
        ``ml_models`` at startup, or "unknown" when none is stored) and
        "torch_version".
    """
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file — confirm it is attached to ``app``.
    active_device = ml_models.get("device", "unknown")
    return dict(
        status="ready",
        device=active_device,
        torch_version=torch.__version__,
    )
async def summarize(request: SummarizeRequest):
    """Summarize a single text.

    Args:
        request: The text to summarize and the requested summary length cap in
            tokens. The cap is clamped to MAX_SUMMARY_TOKENS.

    Returns:
        dict with "summary" (the decoded summary, stripped of surrounding
        whitespace) and "used_device" (the device the model ran on).

    Raises:
        HTTPException: 422 when ``max_length`` is not a positive integer,
            507 when CUDA runs out of memory, 500 for any other failure
            (which is also logged and reported to Sentry).
    """
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file — confirm it is attached to ``app``.
    # NOTE(review): generate() below is a synchronous call inside an async
    # handler, so other requests presumably wait while it runs — confirm this
    # is acceptable for the expected load.

    # Reject meaningless caps up front so the client gets a validation error
    # instead of whatever generate() does with a non-positive length.
    if request.max_length < 1:
        raise HTTPException(
            status_code=422,
            detail=f"max_length must be at least 1, got {request.max_length}",
        )

    try:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        # Tokenize like main_test_finetuned_model.py (max_length + truncation)
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                # Forward the mask explicitly, matching the batch endpoint.
                attention_mask=inputs.get("attention_mask"),
                max_length=min(request.max_length, MAX_SUMMARY_TOKENS),
                **GENERATE_CONFIG,
            )

        # Seq2Seq output is the summary only (no input tokens)
        summary_text = tokenizer.decode(
            summary_ids[0], skip_special_tokens=True
        ).strip()
        return {"summary": summary_text, "used_device": device}
    except torch.cuda.OutOfMemoryError as exc:
        raise HTTPException(status_code=507, detail="GPU Memory Full") from exc
    except Exception as exc:
        logger.exception("Summarize failed")
        sentry_sdk.capture_exception(exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc
async def summarize_batch(request: SummarizeBatchRequest):
    """Summarize multiple texts in one batched generation call.

    All texts are tokenized together (padded to a common length), generated
    with the largest per-item cap, and each output is then cut to its own
    item's cap before decoding.

    Args:
        request: The items to summarize. Each item's ``max_length`` is clamped
            to MAX_SUMMARY_TOKENS.

    Returns:
        dict with "summaries" (one string per item, in the same order as
        ``request.items``) and "used_device". An empty ``items`` list yields
        an empty "summaries" list without touching the model.

    Raises:
        HTTPException: 422 when more than MAX_BATCH_SIZE items are sent or any
            item's ``max_length`` is not a positive integer, 507 when CUDA
            runs out of memory, 500 for any other failure (which is also
            logged and reported to Sentry).
    """
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file — confirm it is attached to ``app``.
    if not request.items:
        return {"summaries": [], "used_device": ml_models.get("device", "unknown")}
    if len(request.items) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=422,
            detail=f"Batch size {len(request.items)} exceeds maximum of {MAX_BATCH_SIZE}",
        )

    # A negative cap would make the ids[:cap] slice below drop tokens from the
    # END of the summary instead of capping it, and zero would yield an empty
    # summary, so such items are rejected before any model work is done.
    invalid_indexes = [
        index for index, item in enumerate(request.items) if item.max_length < 1
    ]
    if invalid_indexes:
        raise HTTPException(
            status_code=422,
            detail=f"max_length must be at least 1 (invalid item indexes: {invalid_indexes})",
        )

    try:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        texts = [item.text for item in request.items]
        max_lengths = [
            min(item.max_length, MAX_SUMMARY_TOKENS) for item in request.items
        ]

        # tokenize batch: padding to same length, truncation
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
            padding=True,
            return_attention_mask=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # One generate() call serves the whole batch, so it has to allow the
        # longest requested summary; shorter caps are applied afterwards.
        max_gen = max(max_lengths)
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_gen,
                **GENERATE_CONFIG,
            )

        # decode each sequence, truncating to per-item max_length (in tokens)
        summaries = [
            tokenizer.decode(ids[:cap], skip_special_tokens=True).strip()
            for ids, cap in zip(summary_ids, max_lengths)
        ]
        return {"summaries": summaries, "used_device": device}
    except torch.cuda.OutOfMemoryError as exc:
        raise HTTPException(status_code=507, detail="GPU Memory Full") from exc
    except Exception as exc:
        logger.exception("Summarize batch failed")
        sentry_sdk.capture_exception(exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the module is launched
    # directly as a script.
    import uvicorn

    # Serve on the loopback interface only, port 8000.
    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)