aleespace / app.py
aliMohammad16's picture
Update app.py
dc08424 verified
raw
history blame
847 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Hugging Face checkpoint used for summarization. Note that this is the
# BART-large variant, so it is not a lightweight model to download or run.
MODEL_NAME = "facebook/bart-large-cnn"

# Load the tokenizer and the seq2seq weights once at import time so that every
# request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Switch to "cuda" if a GPU is available; the request tensors built in the
# endpoint would then have to be moved to the same device as well.
model = model.to("cpu")

# Application object that the route decorators in this module attach to.
app = FastAPI()
# Request body schema for the /summarize endpoint.
class InputText(BaseModel):
    # Raw document to be summarized.
    text: str
def _generate_summary_ids(input_ids, attention_mask):
    """Call ``model.generate`` synchronously and return the generated token ids."""
    return model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=150,
        min_length=50,
        length_penalty=2.0,
    )


@app.post("/summarize")
async def summarize_text(input_text: InputText):
    """Summarize the posted text with the module-level model.

    Args:
        input_text: Request body whose ``text`` field holds the document.

    Returns:
        A dict of the form ``{"summary": <str>}``. Blank or whitespace-only
        input yields an empty summary without invoking the model.
    """
    import asyncio

    # Nothing to summarize: skip the model, which min_length would otherwise
    # force to emit at least 50 tokens for an input with no content.
    if not input_text.text.strip():
        return {"summary": ""}

    # Inputs longer than 1024 tokens are truncated by the tokenizer.
    inputs = tokenizer(
        input_text.text, return_tensors="pt", max_length=1024, truncation=True
    )

    # generate() is a blocking call; running it directly inside this coroutine
    # would stall the event loop (and every other request) until it finishes,
    # so it is handed to the default thread-pool executor instead.
    loop = asyncio.get_running_loop()
    summary_ids = await loop.run_in_executor(
        None, _generate_summary_ids, inputs.input_ids, inputs.attention_mask
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}