Spaces:
Sleeping
Sleeping
File size: 6,087 Bytes
8c73f65 e5f9579 be3904a e5f9579 be3904a e5f9579 be3904a 8c73f65 be3904a 8c73f65 be3904a 8c73f65 be3904a 8c73f65 be3904a 8c73f65 be3904a e5f9579 f6cbcb4 e5f9579 be3904a 8c73f65 be3904a 8c73f65 be3904a 8c73f65 be3904a 8c73f65 be3904a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | """FastAPI service for BART-based text summarization. Exposes single and batch summarize endpoints."""
import asyncio
import logging
import os
import threading
from contextlib import asynccontextmanager

import sentry_sdk
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Populate os.environ from a .env file (when one is found) before any
# configuration, such as SENTRY_DSN below, is read from the environment.
load_dotenv()

# Error monitoring via Sentry. The DSN is taken from the SENTRY_DSN environment
# variable (None when unset). send_default_pii=True is the Sentry SDK option
# that allows personally identifiable request data to be attached to events.
# NOTE(review): submitted texts may be sensitive -- confirm that forwarding PII
# to Sentry is intended for this service.
_sentry_dsn = os.environ.get("SENTRY_DSN")
sentry_sdk.init(dsn=_sentry_dsn, send_default_pii=True)
sentry_sdk.set_tag("service", "api")

logger = logging.getLogger(__name__)
# 1. Define Request Schemas
class SummarizeRequest(BaseModel):
    """One text to summarize.

    Attributes:
        text: Input text. The endpoints truncate it to 1024 tokens before
            generation.
        max_length: Requested summary length cap, in tokens (default 80).
            Must be >= 1; the endpoints additionally cap it at
            MAX_SUMMARY_TOKENS.
    """

    text: str
    # ge=1: zero or negative values would otherwise be passed straight to
    # model.generate as max_length; reject them up front with a 422 instead.
    max_length: int = Field(default=80, ge=1)


class SummarizeBatchRequest(BaseModel):
    """Several texts summarized together in one batched generate call.

    Attributes:
        items: The texts to summarize; each carries its own max_length.
    """

    items: list[SummarizeRequest]
# 2. Limits and generation settings
MAX_BATCH_SIZE = 16  # largest number of items /summarize/batch accepts per request
MAX_SUMMARY_TOKENS = 60  # ceiling applied to every requested max_length

# Beam-search keyword arguments passed to every model.generate call.
# NOTE(review): min_length (30) exceeds the effective max_length whenever a
# request asks for fewer than 30 tokens -- confirm that combination is intended.
GENERATE_CONFIG = dict(
    num_beams=4,
    min_length=30,
    no_repeat_ngram_size=2,
    length_penalty=2.0,
    repetition_penalty=2.5,
    early_stopping=True,
)

# Filled by lifespan() at startup ("tokenizer", "model", "device") and cleared
# again on shutdown.
ml_models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the summarization model and tokenizer on startup; release them on shutdown.

    Picks CUDA when ``torch.cuda.is_available()`` and the CPU otherwise, then
    stores the tokenizer, the model (moved to that device) and the device
    string in ``ml_models`` under "tokenizer", "model" and "device".

    The checkpoint is the model id or local path given by the
    ``SUMMARIZER_MODEL_ID`` environment variable, defaulting to the fine-tuned
    BART checkpoint below.

    Args:
        app: The FastAPI application (required by the lifespan signature; unused).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"--- System Check: Using {device.upper()} ---")
    if device == "cuda":
        print(f"--- GPU Name: {torch.cuda.get_device_name(0)} ---")

    model_id = os.environ.get(
        "SUMMARIZER_MODEL_ID", "janrudolf/bart-finetuned-wobbly-bush-56-h9vtgsw1"
    )
    print(f"--- Loading model: {model_id} ---")
    try:
        # Tokenizer and model come from the same id so their vocabularies match.
        ml_models["tokenizer"] = AutoTokenizer.from_pretrained(model_id)
        ml_models["model"] = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
        ml_models["device"] = device
        yield
    finally:
        # Runs on normal shutdown and also when loading or serving raises, so a
        # partially loaded model is never left referenced.
        ml_models.clear()
        if torch.cuda.is_available():
            # Release PyTorch's cached, now-unused CUDA memory.
            torch.cuda.empty_cache()
# The ASGI application. lifespan() runs its startup half (model loading) before
# requests are served and its cleanup half on shutdown.
app = FastAPI(lifespan=lifespan)
# 3. API Endpoints
@app.get("/health")
def health():
    """Health-check endpoint.

    Returns a JSON object with a constant "ready" status, the device recorded
    by lifespan() ("cpu" or "cuda"; "unknown" when none is recorded) and the
    installed PyTorch version.
    """
    # NOTE(review): status is "ready" even when ml_models is empty; consider
    # reporting whether the model is loaded if this backs a readiness probe.
    current_device = ml_models.get("device", "unknown")
    return dict(
        status="ready",
        device=current_device,
        torch_version=torch.__version__,
    )
# Serializes inference. The model and tokenizer are shared process-wide, and
# the tokenizer is not guaranteed to be thread-safe. Running one generation at
# a time also bounds peak GPU memory to a single batch.
_INFERENCE_LOCK = threading.Lock()


def _generate_summaries(
    texts: list[str], max_lengths: list[int]
) -> tuple[list[str], str]:
    """Summarize ``texts`` with one batched ``model.generate`` call (blocking).

    Args:
        texts: Input documents. Each is truncated to 1024 tokens, and the batch
            is padded to a common length.
        max_lengths: Per-text summary caps in tokens, already clipped to
            MAX_SUMMARY_TOKENS; one entry per text. Generation runs up to the
            largest cap, and each output is then cut to its own cap. A shorter
            cap therefore yields a prefix of that beam-search result rather
            than an independently generated summary.

    Returns:
        ``(summaries, device)``, with summaries in the same order as ``texts``.
    """
    with _INFERENCE_LOCK:
        model = ml_models["model"]
        tokenizer = ml_models["tokenizer"]
        device = ml_models["device"]

        encoded = tokenizer(
            texts,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
            padding=True,
            return_attention_mask=True,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # torch.no_grad() is thread-local, so it must be entered here in the
        # worker thread, not around the awaiting coroutine.
        with torch.no_grad():
            summary_ids = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=max(max_lengths),
                **GENERATE_CONFIG,
            )

        # Seq2seq output holds only summary tokens (no echoed input); cut each
        # sequence to its own cap (in tokens) before decoding.
        summaries = [
            tokenizer.decode(ids[:cap], skip_special_tokens=True).strip()
            for ids, cap in zip(summary_ids, max_lengths)
        ]
    return summaries, device


async def _summarize_or_raise(
    texts: list[str], max_lengths: list[int], failure_message: str
) -> tuple[list[str], str]:
    """Run ``_generate_summaries`` off the event loop and map failures to HTTP errors.

    Args:
        texts: Passed through to ``_generate_summaries``.
        max_lengths: Passed through to ``_generate_summaries``.
        failure_message: Log message used when generation fails.

    Returns:
        ``(summaries, device)`` from ``_generate_summaries``.

    Raises:
        HTTPException: 507 on CUDA out-of-memory. Any other error becomes a 500,
            after being logged and reported to Sentry.
    """
    try:
        # Generation is blocking and can take seconds. Running it in a worker
        # thread keeps the event loop (and /health) responsive in the meantime.
        return await asyncio.to_thread(_generate_summaries, texts, max_lengths)
    except torch.cuda.OutOfMemoryError:
        logger.warning("%s: CUDA out of memory", failure_message)
        raise HTTPException(status_code=507, detail="GPU Memory Full") from None
    except Exception as exc:
        logger.exception(failure_message)
        sentry_sdk.capture_exception(exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc


@app.post("/summarize")
async def summarize(request: SummarizeRequest):
    """Summarize a single text.

    The requested ``max_length`` is capped at MAX_SUMMARY_TOKENS.

    Returns:
        ``{"summary": str, "used_device": str}``.

    Raises:
        HTTPException: 507 on CUDA out-of-memory, 500 on any other failure.
    """
    summaries, device = await _summarize_or_raise(
        [request.text],
        [min(request.max_length, MAX_SUMMARY_TOKENS)],
        "Summarize failed",
    )
    return {"summary": summaries[0], "used_device": device}


@app.post("/summarize/batch")
async def summarize_batch(request: SummarizeBatchRequest):
    """Summarize several texts with a single batched generate call.

    Summaries are returned in the same order as ``request.items``. Each item's
    ``max_length`` is capped at MAX_SUMMARY_TOKENS. An empty batch returns
    immediately without touching the model.

    Returns:
        ``{"summaries": list[str], "used_device": str}``.

    Raises:
        HTTPException: 422 when more than MAX_BATCH_SIZE items are sent, 507 on
            CUDA out-of-memory, 500 on any other failure.
    """
    if not request.items:
        return {"summaries": [], "used_device": ml_models.get("device", "unknown")}
    if len(request.items) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=422,
            detail=f"Batch size {len(request.items)} exceeds maximum of {MAX_BATCH_SIZE}",
        )
    summaries, device = await _summarize_or_raise(
        [item.text for item in request.items],
        [min(item.max_length, MAX_SUMMARY_TOKENS) for item in request.items],
        "Summarize batch failed",
    )
    return {"summaries": summaries, "used_device": device}
if __name__ == "__main__":
    # Local development entry point: serve on the loopback interface only.
    import uvicorn

    uvicorn.run(app, port=8000, host="127.0.0.1")
|